1 | //===-- SPIRVPreLegalizer.cpp - prepare IR for legalization -----*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // The pass prepares IR for legalization: it assigns SPIR-V types to registers |
// and removes intrinsics which held these types during IR translation.
11 | // Also it processes constants and registers them in GR to avoid duplication. |
12 | // |
13 | //===----------------------------------------------------------------------===// |
14 | |
15 | #include "SPIRV.h" |
16 | #include "SPIRVSubtarget.h" |
17 | #include "SPIRVUtils.h" |
18 | #include "llvm/ADT/PostOrderIterator.h" |
19 | #include "llvm/Analysis/OptimizationRemarkEmitter.h" |
20 | #include "llvm/IR/Attributes.h" |
21 | #include "llvm/IR/Constants.h" |
22 | #include "llvm/IR/DebugInfoMetadata.h" |
23 | #include "llvm/IR/IntrinsicsSPIRV.h" |
24 | #include "llvm/Target/TargetIntrinsicInfo.h" |
25 | |
26 | #define DEBUG_TYPE "spirv-prelegalizer" |
27 | |
28 | using namespace llvm; |
29 | |
30 | namespace { |
// Legacy-PM machine function pass that prepares gMIR for the SPIR-V
// legalizer; see the file header comment for the list of transformations.
class SPIRVPreLegalizer : public MachineFunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid.
  SPIRVPreLegalizer() : MachineFunctionPass(ID) {
    initializeSPIRVPreLegalizerPass(*PassRegistry::getPassRegistry());
  }
  bool runOnMachineFunction(MachineFunction &MF) override;
};
39 | } // namespace |
40 | |
// Find all llvm.spv.track_constant intrinsic calls, register the tracked
// constants in the global registry (GR) so that each constant is emitted only
// once per machine function, then erase the intrinsics and rewire their
// result registers to the canonical register of each constant.
static void addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR) {
  MachineRegisterInfo &MRI = MF.getRegInfo();
  // Maps an intrinsic MI to the canonical vreg when its constant was already
  // registered in the duplicate tracker by an earlier instruction.
  DenseMap<MachineInstr *, Register> RegsAlreadyAddedToDT;
  SmallVector<MachineInstr *, 10> ToErase, ToEraseComposites;
  for (MachineBasicBlock &MBB : MF) {
    for (MachineInstr &MI : MBB) {
      if (!isSpvIntrinsic(MI, Intrinsic::spv_track_constant))
        continue;
      ToErase.push_back(&MI);
      // Operand 3 carries the tracked llvm::Constant wrapped in metadata.
      auto *Const =
          cast<Constant>(cast<ConstantAsMetadata>(
                             MI.getOperand(3).getMetadata()->getOperand(0))
                             ->getValue());
      if (auto *GV = dyn_cast<GlobalValue>(Const)) {
        Register Reg = GR->find(GV, &MF);
        if (!Reg.isValid())
          GR->add(GV, &MF, MI.getOperand(2).getReg());
        else
          RegsAlreadyAddedToDT[&MI] = Reg;
      } else {
        Register Reg = GR->find(Const, &MF);
        if (!Reg.isValid()) {
          if (auto *ConstVec = dyn_cast<ConstantDataVector>(Const)) {
            auto *BuildVec = MRI.getVRegDef(MI.getOperand(2).getReg());
            assert(BuildVec &&
                   BuildVec->getOpcode() == TargetOpcode::G_BUILD_VECTOR);
            for (unsigned i = 0; i < ConstVec->getNumElements(); ++i) {
              // Ensure that OpConstantComposite reuses a constant when it's
              // already created and available in the same machine function.
              Constant *ElemConst = ConstVec->getElementAsConstant(i);
              Register ElemReg = GR->find(ElemConst, &MF);
              if (!ElemReg.isValid())
                GR->add(ElemConst, &MF, BuildVec->getOperand(1 + i).getReg());
              else
                BuildVec->getOperand(1 + i).setReg(ElemReg);
            }
          }
          GR->add(Const, &MF, MI.getOperand(2).getReg());
        } else {
          RegsAlreadyAddedToDT[&MI] = Reg;
          // This MI is unused and will be removed. If the MI uses
          // const_composite, it will be unused and should be removed too.
          assert(MI.getOperand(2).isReg() && "Reg operand is expected");
          MachineInstr *SrcMI = MRI.getVRegDef(MI.getOperand(2).getReg());
          if (SrcMI && isSpvIntrinsic(*SrcMI, Intrinsic::spv_const_composite))
            ToEraseComposites.push_back(SrcMI);
        }
      }
    }
  }
  // Replace each intrinsic's result with the canonical constant register and
  // erase the intrinsic itself.
  for (MachineInstr *MI : ToErase) {
    Register Reg = MI->getOperand(2).getReg();
    if (RegsAlreadyAddedToDT.contains(MI))
      Reg = RegsAlreadyAddedToDT[MI];
    // Inherit the def's register class if the canonical register lacks one.
    auto *RC = MRI.getRegClassOrNull(MI->getOperand(0).getReg());
    if (!MRI.getRegClassOrNull(Reg) && RC)
      MRI.setRegClass(Reg, RC);
    MRI.replaceRegWith(MI->getOperand(0).getReg(), Reg);
    MI->eraseFromParent();
  }
  for (MachineInstr *MI : ToEraseComposites)
    MI->eraseFromParent();
}
104 | |
// Replace the G_CONSTANT register operands of llvm.spv.assign_name intrinsics
// with immediate operands holding the constants' values, and erase constant
// definitions that become dead as a result.
static void foldConstantsIntoIntrinsics(MachineFunction &MF) {
  SmallVector<MachineInstr *, 10> ToErase;
  MachineRegisterInfo &MRI = MF.getRegInfo();
  // Register operands start right after the explicit defs plus the intrinsic
  // ID and target-register operands.
  const unsigned AssignNameOperandShift = 2;
  for (MachineBasicBlock &MBB : MF) {
    for (MachineInstr &MI : MBB) {
      if (!isSpvIntrinsic(MI, Intrinsic::spv_assign_name))
        continue;
      unsigned NumOp = MI.getNumExplicitDefs() + AssignNameOperandShift;
      // removeOperand() shifts the remaining operands down, so keep looking
      // at the same index; the appended immediates accumulate at the end and
      // terminate the loop once one of them reaches position NumOp.
      while (MI.getOperand(NumOp).isReg()) {
        MachineOperand &MOp = MI.getOperand(NumOp);
        MachineInstr *ConstMI = MRI.getVRegDef(MOp.getReg());
        assert(ConstMI->getOpcode() == TargetOpcode::G_CONSTANT);
        MI.removeOperand(NumOp);
        MI.addOperand(MachineOperand::CreateImm(
            ConstMI->getOperand(1).getCImm()->getZExtValue()));
        if (MRI.use_empty(ConstMI->getOperand(0).getReg()))
          ToErase.push_back(ConstMI);
      }
    }
  }
  for (MachineInstr *MI : ToErase)
    MI->eraseFromParent();
}
129 | |
130 | static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR, |
131 | MachineIRBuilder MIB) { |
132 | // Get access to information about available extensions |
133 | const SPIRVSubtarget *ST = |
134 | static_cast<const SPIRVSubtarget *>(&MIB.getMF().getSubtarget()); |
135 | SmallVector<MachineInstr *, 10> ToErase; |
136 | for (MachineBasicBlock &MBB : MF) { |
137 | for (MachineInstr &MI : MBB) { |
138 | if (!isSpvIntrinsic(MI, Intrinsic::spv_bitcast) && |
139 | !isSpvIntrinsic(MI, Intrinsic::spv_ptrcast)) |
140 | continue; |
141 | assert(MI.getOperand(2).isReg()); |
142 | MIB.setInsertPt(MBB&: *MI.getParent(), II: MI); |
143 | ToErase.push_back(Elt: &MI); |
144 | if (isSpvIntrinsic(MI, Intrinsic::spv_bitcast)) { |
145 | MIB.buildBitcast(Dst: MI.getOperand(i: 0).getReg(), Src: MI.getOperand(i: 2).getReg()); |
146 | continue; |
147 | } |
148 | Register Def = MI.getOperand(i: 0).getReg(); |
149 | Register Source = MI.getOperand(i: 2).getReg(); |
150 | SPIRVType *BaseTy = GR->getOrCreateSPIRVType( |
151 | getMDOperandAsType(N: MI.getOperand(i: 3).getMetadata(), I: 0), MIB); |
152 | SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType( |
153 | BaseType: BaseTy, I&: MI, TII: *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(), |
154 | SClass: addressSpaceToStorageClass(AddrSpace: MI.getOperand(i: 4).getImm(), STI: *ST)); |
155 | |
156 | // If the bitcast would be redundant, replace all uses with the source |
157 | // register. |
158 | if (GR->getSPIRVTypeForVReg(VReg: Source) == AssignedPtrType) { |
159 | MIB.getMRI()->replaceRegWith(FromReg: Def, ToReg: Source); |
160 | } else { |
161 | GR->assignSPIRVTypeToVReg(Type: AssignedPtrType, VReg: Def, MF); |
162 | MIB.buildBitcast(Dst: Def, Src: Source); |
163 | } |
164 | } |
165 | } |
166 | for (MachineInstr *MI : ToErase) |
167 | MI->eraseFromParent(); |
168 | } |
169 | |
// When translating a GV, the IRTranslator sometimes generates the following IR:
171 | // %1 = G_GLOBAL_VALUE |
172 | // %2 = COPY %1 |
173 | // %3 = G_ADDRSPACE_CAST %2 |
174 | // |
175 | // or |
176 | // |
177 | // %1 = G_ZEXT %2 |
178 | // G_MEMCPY ... %2 ... |
179 | // |
180 | // New registers have no SPIRVType and no register class info. |
181 | // |
182 | // Set SPIRVType for GV, propagate it from GV to other instructions, |
183 | // also set register classes. |
static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
                                     MachineRegisterInfo &MRI,
                                     MachineIRBuilder &MIB) {
  SPIRVType *SpirvTy = nullptr;
  assert(MI && "Machine instr is expected");
  if (MI->getOperand(0).isReg()) {
    Register Reg = MI->getOperand(0).getReg();
    SpirvTy = GR->getSPIRVTypeForVReg(Reg);
    if (!SpirvTy) {
      switch (MI->getOpcode()) {
      case TargetOpcode::G_CONSTANT: {
        // The SPIR-V type is derived from the LLVM type of the immediate.
        MIB.setInsertPt(*MI->getParent(), MI);
        Type *Ty = MI->getOperand(1).getCImm()->getType();
        SpirvTy = GR->getOrCreateSPIRVType(Ty, MIB);
        break;
      }
      case TargetOpcode::G_GLOBAL_VALUE: {
        // A global is typed as a pointer to its deduced element type in the
        // global's address space.
        MIB.setInsertPt(*MI->getParent(), MI);
        const GlobalValue *Global = MI->getOperand(1).getGlobal();
        Type *ElementTy = GR->getDeducedGlobalValueType(Global);
        auto *Ty = TypedPointerType::get(ElementTy,
                                         Global->getType()->getAddressSpace());
        SpirvTy = GR->getOrCreateSPIRVType(Ty, MIB);
        break;
      }
      case TargetOpcode::G_ZEXT: {
        // Recurse into the source definition, then widen its scalar (or
        // vector-element) bit width to at least the destination's width.
        if (MI->getOperand(1).isReg()) {
          if (MachineInstr *DefInstr =
                  MRI.getVRegDef(MI->getOperand(1).getReg())) {
            if (SPIRVType *Def = propagateSPIRVType(DefInstr, GR, MRI, MIB)) {
              unsigned CurrentBW = GR->getScalarOrVectorBitWidth(Def);
              unsigned ExpectedBW =
                  std::max(MRI.getType(Reg).getScalarSizeInBits(), CurrentBW);
              unsigned NumElements = GR->getScalarOrVectorComponentCount(Def);
              SpirvTy = GR->getOrCreateSPIRVIntegerType(ExpectedBW, MIB);
              if (NumElements > 1)
                SpirvTy =
                    GR->getOrCreateSPIRVVectorType(SpirvTy, NumElements, MIB);
            }
          }
        }
        break;
      }
      case TargetOpcode::G_PTRTOINT:
        // The result is an integer of the destination LLT's scalar width.
        SpirvTy = GR->getOrCreateSPIRVIntegerType(
            MRI.getType(Reg).getScalarSizeInBits(), MIB);
        break;
      case TargetOpcode::G_TRUNC:
      case TargetOpcode::G_ADDRSPACE_CAST:
      case TargetOpcode::G_PTR_ADD:
      case TargetOpcode::COPY: {
        // These opcodes keep the SPIR-V type of their source: take it from
        // the source operand's defining instruction, recursively.
        MachineOperand &Op = MI->getOperand(1);
        MachineInstr *Def = Op.isReg() ? MRI.getVRegDef(Op.getReg()) : nullptr;
        if (Def)
          SpirvTy = propagateSPIRVType(Def, GR, MRI, MIB);
        break;
      }
      default:
        break;
      }
      if (SpirvTy)
        GR->assignSPIRVTypeToVReg(SpirvTy, Reg, MIB.getMF());
      if (!MRI.getRegClassOrNull(Reg))
        MRI.setRegClass(Reg, &SPIRV::IDRegClass);
    }
  }
  return SpirvTy;
}
252 | |
// Create a fresh virtual register suitable for a GET_* pseudo that converts
// SrcReg (of SPIR-V type SpvType, looked up from SrcReg when null) into an
// "ID" register consumable by tblgen'erated selection patterns. Returns the
// new register paired with the matching GET_* opcode.
static std::pair<Register, unsigned>
createNewIdReg(SPIRVType *SpvType, Register SrcReg, MachineRegisterInfo &MRI,
               const SPIRVGlobalRegistry &GR) {
  if (!SpvType)
    SpvType = GR.getSPIRVTypeForVReg(SrcReg);
  assert(SpvType && "VReg is expected to have SPIRV type");
  LLT SrcLLT = MRI.getType(SrcReg);
  LLT NewT = LLT::scalar(32);
  // Floats (scalar or vector-of-float) use the f-flavored GET op/regclass.
  bool IsFloat = SpvType->getOpcode() == SPIRV::OpTypeFloat;
  bool IsVectorFloat =
      SpvType->getOpcode() == SPIRV::OpTypeVector &&
      GR.getSPIRVTypeForVReg(SpvType->getOperand(1).getReg())->getOpcode() ==
          SPIRV::OpTypeFloat;
  IsFloat |= IsVectorFloat;
  auto GetIdOp = IsFloat ? SPIRV::GET_fID : SPIRV::GET_ID;
  auto DstClass = IsFloat ? &SPIRV::fIDRegClass : &SPIRV::IDRegClass;
  if (SrcLLT.isPointer()) {
    // Pointers (and vectors of pointers) pick the 32- or 64-bit pID variants
    // based on the target's pointer size.
    unsigned PtrSz = GR.getPointerSize();
    NewT = LLT::pointer(0, PtrSz);
    bool IsVec = SrcLLT.isVector();
    if (IsVec)
      NewT = LLT::fixed_vector(2, NewT);
    if (PtrSz == 64) {
      if (IsVec) {
        GetIdOp = SPIRV::GET_vpID64;
        DstClass = &SPIRV::vpID64RegClass;
      } else {
        GetIdOp = SPIRV::GET_pID64;
        DstClass = &SPIRV::pID64RegClass;
      }
    } else {
      if (IsVec) {
        GetIdOp = SPIRV::GET_vpID32;
        DstClass = &SPIRV::vpID32RegClass;
      } else {
        GetIdOp = SPIRV::GET_pID32;
        DstClass = &SPIRV::pID32RegClass;
      }
    }
  } else if (SrcLLT.isVector()) {
    NewT = LLT::fixed_vector(2, NewT);
    if (IsFloat) {
      GetIdOp = SPIRV::GET_vfID;
      DstClass = &SPIRV::vfIDRegClass;
    } else {
      GetIdOp = SPIRV::GET_vID;
      DstClass = &SPIRV::vIDRegClass;
    }
  }
  Register IdReg = MRI.createGenericVirtualRegister(NewT);
  MRI.setRegClass(IdReg, DstClass);
  return {IdReg, GetIdOp};
}
306 | |
// Insert ASSIGN_TYPE instruction between Reg and its definition, set NewReg as
308 | // a dst of the definition, assign SPIRVType to both registers. If SpirvTy is |
309 | // provided, use it as SPIRVType in ASSIGN_TYPE, otherwise create it from Ty. |
310 | // It's used also in SPIRVBuiltins.cpp. |
311 | // TODO: maybe move to SPIRVUtils. |
312 | namespace llvm { |
Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
                           SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB,
                           MachineRegisterInfo &MRI) {
  MachineInstr *Def = MRI.getVRegDef(Reg);
  assert((Ty || SpirvTy) && "Either LLVM or SPIRV type is expected.");
  // Insert right after the defining instruction (or at the block's end).
  MIB.setInsertPt(*Def->getParent(),
                  (Def->getNextNode() ? Def->getNextNode()->getIterator()
                                      : Def->getParent()->end()));
  SpirvTy = SpirvTy ? SpirvTy : GR->getOrCreateSPIRVType(Ty, MIB);
  Register NewReg = MRI.createGenericVirtualRegister(MRI.getType(Reg));
  if (auto *RC = MRI.getRegClassOrNull(Reg)) {
    MRI.setRegClass(NewReg, RC);
  } else {
    MRI.setRegClass(NewReg, &SPIRV::IDRegClass);
    MRI.setRegClass(Reg, &SPIRV::IDRegClass);
  }
  GR->assignSPIRVTypeToVReg(SpirvTy, Reg, MIB.getMF());
  // This is to make it convenient for Legalizer to get the SPIRVType
  // when processing the actual MI (i.e. not pseudo one).
  GR->assignSPIRVTypeToVReg(SpirvTy, NewReg, MIB.getMF());
  // Copy MIFlags from Def to ASSIGN_TYPE instruction. It's required to keep
  // the flags after instruction selection.
  const uint32_t Flags = Def->getFlags();
  MIB.buildInstr(SPIRV::ASSIGN_TYPE)
      .addDef(Reg)
      .addUse(NewReg)
      .addUse(GR->getSPIRVTypeID(SpirvTy))
      .setMIFlags(Flags);
  // Rewire: the original instruction now defines NewReg, and the ASSIGN_TYPE
  // pseudo defines the original Reg.
  Def->getOperand(0).setReg(NewReg);
  return NewReg;
}
344 | |
// Rewrite a type-folded instruction MI: its def and all its register uses are
// funneled through freshly created "ID" registers (see createNewIdReg) so the
// instruction can be matched by tblgen'erated patterns. MI's single user is
// expected to be the ASSIGN_TYPE pseudo created by insertAssignInstr.
void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
                  MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR) {
  assert(MI.getNumDefs() > 0 && MRI.hasOneUse(MI.getOperand(0).getReg()));
  MachineInstr &AssignTypeInst =
      *(MRI.use_instr_begin(MI.getOperand(0).getReg()));
  auto NewReg =
      createNewIdReg(nullptr, MI.getOperand(0).getReg(), MRI, *GR).first;
  AssignTypeInst.getOperand(1).setReg(NewReg);
  MI.getOperand(0).setReg(NewReg);
  MIB.setInsertPt(*MI.getParent(),
                  (MI.getNextNode() ? MI.getNextNode()->getIterator()
                                    : MI.getParent()->end()));
  for (auto &Op : MI.operands()) {
    if (!Op.isReg() || Op.isDef())
      continue;
    // Materialize a GET_* pseudo converting the use into an ID register.
    auto IdOpInfo = createNewIdReg(nullptr, Op.getReg(), MRI, *GR);
    MIB.buildInstr(IdOpInfo.second).addDef(IdOpInfo.first).addUse(Op.getReg());
    Op.setReg(IdOpInfo.first);
  }
}
365 | } // namespace llvm |
366 | |
367 | static void generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR, |
368 | MachineIRBuilder MIB) { |
369 | // Get access to information about available extensions |
370 | const SPIRVSubtarget *ST = |
371 | static_cast<const SPIRVSubtarget *>(&MIB.getMF().getSubtarget()); |
372 | |
373 | MachineRegisterInfo &MRI = MF.getRegInfo(); |
374 | SmallVector<MachineInstr *, 10> ToErase; |
375 | |
376 | for (MachineBasicBlock *MBB : post_order(G: &MF)) { |
377 | if (MBB->empty()) |
378 | continue; |
379 | |
380 | bool ReachedBegin = false; |
381 | for (auto MII = std::prev(x: MBB->end()), Begin = MBB->begin(); |
382 | !ReachedBegin;) { |
383 | MachineInstr &MI = *MII; |
384 | |
385 | if (isSpvIntrinsic(MI, Intrinsic::spv_assign_ptr_type)) { |
386 | Register Reg = MI.getOperand(i: 1).getReg(); |
387 | MIB.setInsertPt(MBB&: *MI.getParent(), II: MI.getIterator()); |
388 | SPIRVType *BaseTy = GR->getOrCreateSPIRVType( |
389 | getMDOperandAsType(N: MI.getOperand(i: 2).getMetadata(), I: 0), MIB); |
390 | SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType( |
391 | BaseType: BaseTy, I&: MI, TII: *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(), |
392 | SClass: addressSpaceToStorageClass(AddrSpace: MI.getOperand(i: 3).getImm(), STI: *ST)); |
393 | MachineInstr *Def = MRI.getVRegDef(Reg); |
394 | assert(Def && "Expecting an instruction that defines the register" ); |
395 | // G_GLOBAL_VALUE already has type info. |
396 | if (Def->getOpcode() != TargetOpcode::G_GLOBAL_VALUE) |
397 | insertAssignInstr(Reg, Ty: nullptr, SpirvTy: AssignedPtrType, GR, MIB, |
398 | MRI&: MF.getRegInfo()); |
399 | ToErase.push_back(Elt: &MI); |
400 | } else if (isSpvIntrinsic(MI, Intrinsic::spv_assign_type)) { |
401 | Register Reg = MI.getOperand(i: 1).getReg(); |
402 | Type *Ty = getMDOperandAsType(N: MI.getOperand(i: 2).getMetadata(), I: 0); |
403 | MachineInstr *Def = MRI.getVRegDef(Reg); |
404 | assert(Def && "Expecting an instruction that defines the register" ); |
405 | // G_GLOBAL_VALUE already has type info. |
406 | if (Def->getOpcode() != TargetOpcode::G_GLOBAL_VALUE) |
407 | insertAssignInstr(Reg, Ty, SpirvTy: nullptr, GR, MIB, MRI&: MF.getRegInfo()); |
408 | ToErase.push_back(Elt: &MI); |
409 | } else if (MI.getOpcode() == TargetOpcode::G_CONSTANT || |
410 | MI.getOpcode() == TargetOpcode::G_FCONSTANT || |
411 | MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR) { |
412 | // %rc = G_CONSTANT ty Val |
413 | // ===> |
414 | // %cty = OpType* ty |
415 | // %rctmp = G_CONSTANT ty Val |
416 | // %rc = ASSIGN_TYPE %rctmp, %cty |
417 | Register Reg = MI.getOperand(i: 0).getReg(); |
418 | if (MRI.hasOneUse(RegNo: Reg)) { |
419 | MachineInstr &UseMI = *MRI.use_instr_begin(RegNo: Reg); |
420 | if (isSpvIntrinsic(UseMI, Intrinsic::spv_assign_type) || |
421 | isSpvIntrinsic(UseMI, Intrinsic::spv_assign_name)) |
422 | continue; |
423 | } |
424 | Type *Ty = nullptr; |
425 | if (MI.getOpcode() == TargetOpcode::G_CONSTANT) |
426 | Ty = MI.getOperand(i: 1).getCImm()->getType(); |
427 | else if (MI.getOpcode() == TargetOpcode::G_FCONSTANT) |
428 | Ty = MI.getOperand(i: 1).getFPImm()->getType(); |
429 | else { |
430 | assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR); |
431 | Type *ElemTy = nullptr; |
432 | MachineInstr *ElemMI = MRI.getVRegDef(Reg: MI.getOperand(i: 1).getReg()); |
433 | assert(ElemMI); |
434 | |
435 | if (ElemMI->getOpcode() == TargetOpcode::G_CONSTANT) |
436 | ElemTy = ElemMI->getOperand(i: 1).getCImm()->getType(); |
437 | else if (ElemMI->getOpcode() == TargetOpcode::G_FCONSTANT) |
438 | ElemTy = ElemMI->getOperand(i: 1).getFPImm()->getType(); |
439 | else |
440 | llvm_unreachable("Unexpected opcode" ); |
441 | unsigned NumElts = |
442 | MI.getNumExplicitOperands() - MI.getNumExplicitDefs(); |
443 | Ty = VectorType::get(ElementType: ElemTy, NumElements: NumElts, Scalable: false); |
444 | } |
445 | insertAssignInstr(Reg, Ty, SpirvTy: nullptr, GR, MIB, MRI); |
446 | } else if (MI.getOpcode() == TargetOpcode::G_TRUNC || |
447 | MI.getOpcode() == TargetOpcode::G_ZEXT || |
448 | MI.getOpcode() == TargetOpcode::G_PTRTOINT || |
449 | MI.getOpcode() == TargetOpcode::G_GLOBAL_VALUE || |
450 | MI.getOpcode() == TargetOpcode::COPY || |
451 | MI.getOpcode() == TargetOpcode::G_ADDRSPACE_CAST) { |
452 | propagateSPIRVType(MI: &MI, GR, MRI, MIB); |
453 | } |
454 | |
455 | if (MII == Begin) |
456 | ReachedBegin = true; |
457 | else |
458 | --MII; |
459 | } |
460 | } |
461 | for (MachineInstr *MI : ToErase) |
462 | MI->eraseFromParent(); |
463 | } |
464 | |
465 | // Defined in SPIRVLegalizerInfo.cpp. |
466 | extern bool isTypeFoldingSupported(unsigned Opcode); |
467 | |
468 | static void processInstrsWithTypeFolding(MachineFunction &MF, |
469 | SPIRVGlobalRegistry *GR, |
470 | MachineIRBuilder MIB) { |
471 | MachineRegisterInfo &MRI = MF.getRegInfo(); |
472 | for (MachineBasicBlock &MBB : MF) { |
473 | for (MachineInstr &MI : MBB) { |
474 | if (isTypeFoldingSupported(Opcode: MI.getOpcode())) |
475 | processInstr(MI, MIB, MRI, GR); |
476 | } |
477 | } |
478 | |
479 | for (MachineBasicBlock &MBB : MF) { |
480 | for (MachineInstr &MI : MBB) { |
481 | // We need to rewrite dst types for ASSIGN_TYPE instrs to be able |
482 | // to perform tblgen'erated selection and we can't do that on Legalizer |
483 | // as it operates on gMIR only. |
484 | if (MI.getOpcode() != SPIRV::ASSIGN_TYPE) |
485 | continue; |
486 | Register SrcReg = MI.getOperand(i: 1).getReg(); |
487 | unsigned Opcode = MRI.getVRegDef(Reg: SrcReg)->getOpcode(); |
488 | if (!isTypeFoldingSupported(Opcode)) |
489 | continue; |
490 | Register DstReg = MI.getOperand(i: 0).getReg(); |
491 | bool IsDstPtr = MRI.getType(Reg: DstReg).isPointer(); |
492 | bool isDstVec = MRI.getType(Reg: DstReg).isVector(); |
493 | if (IsDstPtr || isDstVec) |
494 | MRI.setRegClass(DstReg, &SPIRV::IDRegClass); |
495 | // Don't need to reset type of register holding constant and used in |
496 | // G_ADDRSPACE_CAST, since it breaks legalizer. |
497 | if (Opcode == TargetOpcode::G_CONSTANT && MRI.hasOneUse(RegNo: DstReg)) { |
498 | MachineInstr &UseMI = *MRI.use_instr_begin(RegNo: DstReg); |
499 | if (UseMI.getOpcode() == TargetOpcode::G_ADDRSPACE_CAST) |
500 | continue; |
501 | } |
502 | MRI.setType(VReg: DstReg, Ty: IsDstPtr ? LLT::pointer(AddressSpace: 0, SizeInBits: GR->getPointerSize()) |
503 | : LLT::scalar(SizeInBits: 32)); |
504 | } |
505 | } |
506 | } |
507 | |
508 | // Find basic blocks of the switch and replace registers in spv_switch() by its |
509 | // MBB equivalent. |
510 | static void processSwitches(MachineFunction &MF, SPIRVGlobalRegistry *GR, |
511 | MachineIRBuilder MIB) { |
512 | DenseMap<const BasicBlock *, MachineBasicBlock *> BB2MBB; |
513 | SmallVector<std::pair<MachineInstr *, SmallVector<MachineInstr *, 8>>> |
514 | Switches; |
515 | for (MachineBasicBlock &MBB : MF) { |
516 | MachineRegisterInfo &MRI = MF.getRegInfo(); |
517 | BB2MBB[MBB.getBasicBlock()] = &MBB; |
518 | for (MachineInstr &MI : MBB) { |
519 | if (!isSpvIntrinsic(MI, Intrinsic::spv_switch)) |
520 | continue; |
521 | // Calls to spv_switch intrinsics representing IR switches. |
522 | SmallVector<MachineInstr *, 8> NewOps; |
523 | for (unsigned i = 2; i < MI.getNumOperands(); ++i) { |
524 | Register Reg = MI.getOperand(i).getReg(); |
525 | if (i % 2 == 1) { |
526 | MachineInstr *ConstInstr = getDefInstrMaybeConstant(ConstReg&: Reg, MRI: &MRI); |
527 | NewOps.push_back(Elt: ConstInstr); |
528 | } else { |
529 | MachineInstr *BuildMBB = MRI.getVRegDef(Reg); |
530 | assert(BuildMBB && |
531 | BuildMBB->getOpcode() == TargetOpcode::G_BLOCK_ADDR && |
532 | BuildMBB->getOperand(1).isBlockAddress() && |
533 | BuildMBB->getOperand(1).getBlockAddress()); |
534 | NewOps.push_back(Elt: BuildMBB); |
535 | } |
536 | } |
537 | Switches.push_back(Elt: std::make_pair(x: &MI, y&: NewOps)); |
538 | } |
539 | } |
540 | |
541 | SmallPtrSet<MachineInstr *, 8> ToEraseMI; |
542 | for (auto &SwIt : Switches) { |
543 | MachineInstr &MI = *SwIt.first; |
544 | SmallVector<MachineInstr *, 8> &Ins = SwIt.second; |
545 | SmallVector<MachineOperand, 8> NewOps; |
546 | for (unsigned i = 0; i < Ins.size(); ++i) { |
547 | if (Ins[i]->getOpcode() == TargetOpcode::G_BLOCK_ADDR) { |
548 | BasicBlock *CaseBB = |
549 | Ins[i]->getOperand(i: 1).getBlockAddress()->getBasicBlock(); |
550 | auto It = BB2MBB.find(Val: CaseBB); |
551 | if (It == BB2MBB.end()) |
552 | report_fatal_error(reason: "cannot find a machine basic block by a basic " |
553 | "block in a switch statement" ); |
554 | NewOps.push_back(Elt: MachineOperand::CreateMBB(MBB: It->second)); |
555 | MI.getParent()->addSuccessor(Succ: It->second); |
556 | ToEraseMI.insert(Ptr: Ins[i]); |
557 | } else { |
558 | NewOps.push_back( |
559 | Elt: MachineOperand::CreateCImm(CI: Ins[i]->getOperand(i: 1).getCImm())); |
560 | } |
561 | } |
562 | for (unsigned i = MI.getNumOperands() - 1; i > 1; --i) |
563 | MI.removeOperand(OpNo: i); |
564 | for (auto &MO : NewOps) |
565 | MI.addOperand(Op: MO); |
566 | if (MachineInstr *Next = MI.getNextNode()) { |
567 | if (isSpvIntrinsic(*Next, Intrinsic::spv_track_constant)) { |
568 | ToEraseMI.insert(Ptr: Next); |
569 | Next = MI.getNextNode(); |
570 | } |
571 | if (Next && Next->getOpcode() == TargetOpcode::G_BRINDIRECT) |
572 | ToEraseMI.insert(Ptr: Next); |
573 | } |
574 | } |
575 | for (MachineInstr *BlockAddrI : ToEraseMI) |
576 | BlockAddrI->eraseFromParent(); |
577 | } |
578 | |
579 | static bool isImplicitFallthrough(MachineBasicBlock &MBB) { |
580 | if (MBB.empty()) |
581 | return true; |
582 | |
583 | // Branching SPIR-V intrinsics are not detected by this generic method. |
584 | // Thus, we can only trust negative result. |
585 | if (!MBB.canFallThrough()) |
586 | return false; |
587 | |
588 | // Otherwise, we must manually check if we have a SPIR-V intrinsic which |
589 | // prevent an implicit fallthrough. |
590 | for (MachineBasicBlock::reverse_iterator It = MBB.rbegin(), E = MBB.rend(); |
591 | It != E; ++It) { |
592 | if (isSpvIntrinsic(*It, Intrinsic::spv_switch)) |
593 | return false; |
594 | } |
595 | return true; |
596 | } |
597 | |
598 | static void removeImplicitFallthroughs(MachineFunction &MF, |
599 | MachineIRBuilder MIB) { |
600 | // It is valid for MachineBasicBlocks to not finish with a branch instruction. |
601 | // In such cases, they will simply fallthrough their immediate successor. |
602 | for (MachineBasicBlock &MBB : MF) { |
603 | if (!isImplicitFallthrough(MBB)) |
604 | continue; |
605 | |
606 | assert(std::distance(MBB.successors().begin(), MBB.successors().end()) == |
607 | 1); |
608 | MIB.setInsertPt(MBB, II: MBB.end()); |
609 | MIB.buildBr(Dest&: **MBB.successors().begin()); |
610 | } |
611 | } |
612 | |
613 | bool SPIRVPreLegalizer::runOnMachineFunction(MachineFunction &MF) { |
614 | // Initialize the type registry. |
615 | const SPIRVSubtarget &ST = MF.getSubtarget<SPIRVSubtarget>(); |
616 | SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry(); |
617 | GR->setCurrentFunc(MF); |
618 | MachineIRBuilder MIB(MF); |
619 | addConstantsToTrack(MF, GR); |
620 | foldConstantsIntoIntrinsics(MF); |
621 | insertBitcasts(MF, GR, MIB); |
622 | generateAssignInstrs(MF, GR, MIB); |
623 | processSwitches(MF, GR, MIB); |
624 | processInstrsWithTypeFolding(MF, GR, MIB); |
625 | removeImplicitFallthroughs(MF, MIB); |
626 | |
627 | return true; |
628 | } |
629 | |
// Register the pass with the legacy pass manager.
INITIALIZE_PASS(SPIRVPreLegalizer, DEBUG_TYPE, "SPIRV pre legalizer" , false,
                false)

// Pass identification: the address of ID is what uniquely identifies the pass.
char SPIRVPreLegalizer::ID = 0;
634 | |
// Factory used by the SPIR-V target's pass pipeline to create this pass.
FunctionPass *llvm::createSPIRVPreLegalizerPass() {
  return new SPIRVPreLegalizer();
}
638 | |