1//===-- SPIRVGlobalRegistry.cpp - SPIR-V Global Registry --------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the implementation of the SPIRVGlobalRegistry class,
10// which is used to maintain rich type information required for SPIR-V even
11// after lowering from LLVM IR to GMIR. It can convert an llvm::Type into
12// an OpTypeXXX instruction, and map it to a virtual register. Also it builds
13// and supports consistency of constants and global variables.
14//
15//===----------------------------------------------------------------------===//
16
17#include "SPIRVGlobalRegistry.h"
18#include "SPIRV.h"
19#include "SPIRVBuiltins.h"
20#include "SPIRVSubtarget.h"
21#include "SPIRVTargetMachine.h"
22#include "SPIRVUtils.h"
23#include "llvm/ADT/APInt.h"
24#include "llvm/IR/Constants.h"
25#include "llvm/IR/Type.h"
26#include "llvm/Support/Casting.h"
27#include <cassert>
28
29using namespace llvm;
30SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize)
31 : PointerSize(PointerSize), Bound(0) {}
32
33SPIRVType *SPIRVGlobalRegistry::assignIntTypeToVReg(unsigned BitWidth,
34 Register VReg,
35 MachineInstr &I,
36 const SPIRVInstrInfo &TII) {
37 SPIRVType *SpirvType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
38 assignSPIRVTypeToVReg(Type: SpirvType, VReg, MF&: *CurMF);
39 return SpirvType;
40}
41
42SPIRVType *
43SPIRVGlobalRegistry::assignFloatTypeToVReg(unsigned BitWidth, Register VReg,
44 MachineInstr &I,
45 const SPIRVInstrInfo &TII) {
46 SPIRVType *SpirvType = getOrCreateSPIRVFloatType(BitWidth, I, TII);
47 assignSPIRVTypeToVReg(Type: SpirvType, VReg, MF&: *CurMF);
48 return SpirvType;
49}
50
51SPIRVType *SPIRVGlobalRegistry::assignVectTypeToVReg(
52 SPIRVType *BaseType, unsigned NumElements, Register VReg, MachineInstr &I,
53 const SPIRVInstrInfo &TII) {
54 SPIRVType *SpirvType =
55 getOrCreateSPIRVVectorType(BaseType, NumElements, I, TII);
56 assignSPIRVTypeToVReg(Type: SpirvType, VReg, MF&: *CurMF);
57 return SpirvType;
58}
59
60SPIRVType *SPIRVGlobalRegistry::assignTypeToVReg(
61 const Type *Type, Register VReg, MachineIRBuilder &MIRBuilder,
62 SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
63 SPIRVType *SpirvType =
64 getOrCreateSPIRVType(BitWidth: Type, I&: MIRBuilder, TII: AccessQual, SPIRVOPcode: EmitIR);
65 assignSPIRVTypeToVReg(Type: SpirvType, VReg, MF&: MIRBuilder.getMF());
66 return SpirvType;
67}
68
69void SPIRVGlobalRegistry::assignSPIRVTypeToVReg(SPIRVType *SpirvType,
70 Register VReg,
71 MachineFunction &MF) {
72 VRegToTypeMap[&MF][VReg] = SpirvType;
73}
74
75static Register createTypeVReg(MachineIRBuilder &MIRBuilder) {
76 auto &MRI = MIRBuilder.getMF().getRegInfo();
77 auto Res = MRI.createGenericVirtualRegister(Ty: LLT::scalar(SizeInBits: 32));
78 MRI.setRegClass(Reg: Res, RC: &SPIRV::TYPERegClass);
79 return Res;
80}
81
82static Register createTypeVReg(MachineRegisterInfo &MRI) {
83 auto Res = MRI.createGenericVirtualRegister(Ty: LLT::scalar(SizeInBits: 32));
84 MRI.setRegClass(Reg: Res, RC: &SPIRV::TYPERegClass);
85 return Res;
86}
87
88SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
89 return MIRBuilder.buildInstr(SPIRV::OpTypeBool)
90 .addDef(createTypeVReg(MIRBuilder));
91}
92
93SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(uint32_t Width,
94 MachineIRBuilder &MIRBuilder,
95 bool IsSigned) {
96 assert(Width <= 64 && "Unsupported integer width!");
97 const SPIRVSubtarget &ST =
98 cast<SPIRVSubtarget>(Val: MIRBuilder.getMF().getSubtarget());
99 if (ST.canUseExtension(
100 SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {
101 MIRBuilder.buildInstr(SPIRV::OpExtension)
102 .addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers);
103 MIRBuilder.buildInstr(SPIRV::OpCapability)
104 .addImm(SPIRV::Capability::ArbitraryPrecisionIntegersINTEL);
105 } else if (Width <= 8)
106 Width = 8;
107 else if (Width <= 16)
108 Width = 16;
109 else if (Width <= 32)
110 Width = 32;
111 else if (Width <= 64)
112 Width = 64;
113
114 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeInt)
115 .addDef(createTypeVReg(MIRBuilder))
116 .addImm(Width)
117 .addImm(IsSigned ? 1 : 0);
118 return MIB;
119}
120
121SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
122 MachineIRBuilder &MIRBuilder) {
123 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
124 .addDef(createTypeVReg(MIRBuilder))
125 .addImm(Width);
126 return MIB;
127}
128
129SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) {
130 return MIRBuilder.buildInstr(SPIRV::OpTypeVoid)
131 .addDef(createTypeVReg(MIRBuilder));
132}
133
134SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems,
135 SPIRVType *ElemType,
136 MachineIRBuilder &MIRBuilder) {
137 auto EleOpc = ElemType->getOpcode();
138 (void)EleOpc;
139 assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat ||
140 EleOpc == SPIRV::OpTypeBool) &&
141 "Invalid vector element type");
142
143 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeVector)
144 .addDef(createTypeVReg(MIRBuilder))
145 .addUse(getSPIRVTypeID(ElemType))
146 .addImm(NumElems);
147 return MIB;
148}
149
150std::tuple<Register, ConstantInt *, bool>
151SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val, SPIRVType *SpvType,
152 MachineIRBuilder *MIRBuilder,
153 MachineInstr *I,
154 const SPIRVInstrInfo *TII) {
155 const IntegerType *LLVMIntTy;
156 if (SpvType)
157 LLVMIntTy = cast<IntegerType>(Val: getTypeForSPIRVType(Ty: SpvType));
158 else
159 LLVMIntTy = IntegerType::getInt32Ty(C&: CurMF->getFunction().getContext());
160 bool NewInstr = false;
161 // Find a constant in DT or build a new one.
162 ConstantInt *CI = ConstantInt::get(Ty: const_cast<IntegerType *>(LLVMIntTy), V: Val);
163 Register Res = DT.find(C: CI, MF: CurMF);
164 if (!Res.isValid()) {
165 unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(Type: SpvType) : 32;
166 // TODO: handle cases where the type is not 32bit wide
167 // TODO: https://github.com/llvm/llvm-project/issues/88129
168 LLT LLTy = LLT::scalar(SizeInBits: 32);
169 Res = CurMF->getRegInfo().createGenericVirtualRegister(Ty: LLTy);
170 CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
171 if (MIRBuilder)
172 assignTypeToVReg(LLVMIntTy, Res, *MIRBuilder);
173 else
174 assignIntTypeToVReg(BitWidth, VReg: Res, I&: *I, TII: *TII);
175 DT.add(C: CI, MF: CurMF, R: Res);
176 NewInstr = true;
177 }
178 return std::make_tuple(args&: Res, args&: CI, args&: NewInstr);
179}
180
181std::tuple<Register, ConstantFP *, bool, unsigned>
182SPIRVGlobalRegistry::getOrCreateConstFloatReg(APFloat Val, SPIRVType *SpvType,
183 MachineIRBuilder *MIRBuilder,
184 MachineInstr *I,
185 const SPIRVInstrInfo *TII) {
186 const Type *LLVMFloatTy;
187 LLVMContext &Ctx = CurMF->getFunction().getContext();
188 unsigned BitWidth = 32;
189 if (SpvType)
190 LLVMFloatTy = getTypeForSPIRVType(Ty: SpvType);
191 else {
192 LLVMFloatTy = Type::getFloatTy(C&: Ctx);
193 if (MIRBuilder)
194 SpvType = getOrCreateSPIRVType(LLVMFloatTy, *MIRBuilder);
195 }
196 bool NewInstr = false;
197 // Find a constant in DT or build a new one.
198 auto *const CI = ConstantFP::get(Context&: Ctx, V: Val);
199 Register Res = DT.find(C: CI, MF: CurMF);
200 if (!Res.isValid()) {
201 if (SpvType)
202 BitWidth = getScalarOrVectorBitWidth(Type: SpvType);
203 // TODO: handle cases where the type is not 32bit wide
204 // TODO: https://github.com/llvm/llvm-project/issues/88129
205 LLT LLTy = LLT::scalar(SizeInBits: 32);
206 Res = CurMF->getRegInfo().createGenericVirtualRegister(Ty: LLTy);
207 CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
208 if (MIRBuilder)
209 assignTypeToVReg(LLVMFloatTy, Res, *MIRBuilder);
210 else
211 assignFloatTypeToVReg(BitWidth, VReg: Res, I&: *I, TII: *TII);
212 DT.add(C: CI, MF: CurMF, R: Res);
213 NewInstr = true;
214 }
215 return std::make_tuple(args&: Res, args: CI, args&: NewInstr, args&: BitWidth);
216}
217
218Register SPIRVGlobalRegistry::getOrCreateConstFP(APFloat Val, MachineInstr &I,
219 SPIRVType *SpvType,
220 const SPIRVInstrInfo &TII,
221 bool ZeroAsNull) {
222 assert(SpvType);
223 ConstantFP *CI;
224 Register Res;
225 bool New;
226 unsigned BitWidth;
227 std::tie(args&: Res, args&: CI, args&: New, args&: BitWidth) =
228 getOrCreateConstFloatReg(Val, SpvType, MIRBuilder: nullptr, I: &I, TII: &TII);
229 // If we have found Res register which is defined by the passed G_CONSTANT
230 // machine instruction, a new constant instruction should be created.
231 if (!New && (!I.getOperand(i: 0).isReg() || Res != I.getOperand(i: 0).getReg()))
232 return Res;
233 MachineInstrBuilder MIB;
234 MachineBasicBlock &BB = *I.getParent();
235 // In OpenCL OpConstantNull - Scalar floating point: +0.0 (all bits 0)
236 if (Val.isPosZero() && ZeroAsNull) {
237 MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull))
238 .addDef(Res)
239 .addUse(getSPIRVTypeID(SpvType));
240 } else {
241 MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantF))
242 .addDef(Res)
243 .addUse(getSPIRVTypeID(SpvType));
244 addNumImm(
245 Imm: APInt(BitWidth, CI->getValueAPF().bitcastToAPInt().getZExtValue()),
246 MIB);
247 }
248 const auto &ST = CurMF->getSubtarget();
249 constrainSelectedInstRegOperands(I&: *MIB, TII: *ST.getInstrInfo(),
250 TRI: *ST.getRegisterInfo(), RBI: *ST.getRegBankInfo());
251 return Res;
252}
253
254Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
255 SPIRVType *SpvType,
256 const SPIRVInstrInfo &TII,
257 bool ZeroAsNull) {
258 assert(SpvType);
259 ConstantInt *CI;
260 Register Res;
261 bool New;
262 std::tie(args&: Res, args&: CI, args&: New) =
263 getOrCreateConstIntReg(Val, SpvType, MIRBuilder: nullptr, I: &I, TII: &TII);
264 // If we have found Res register which is defined by the passed G_CONSTANT
265 // machine instruction, a new constant instruction should be created.
266 if (!New && (!I.getOperand(i: 0).isReg() || Res != I.getOperand(i: 0).getReg()))
267 return Res;
268 MachineInstrBuilder MIB;
269 MachineBasicBlock &BB = *I.getParent();
270 if (Val || !ZeroAsNull) {
271 MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantI))
272 .addDef(Res)
273 .addUse(getSPIRVTypeID(SpvType));
274 addNumImm(Imm: APInt(getScalarOrVectorBitWidth(Type: SpvType), Val), MIB);
275 } else {
276 MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull))
277 .addDef(Res)
278 .addUse(getSPIRVTypeID(SpvType));
279 }
280 const auto &ST = CurMF->getSubtarget();
281 constrainSelectedInstRegOperands(I&: *MIB, TII: *ST.getInstrInfo(),
282 TRI: *ST.getRegisterInfo(), RBI: *ST.getRegBankInfo());
283 return Res;
284}
285
286Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val,
287 MachineIRBuilder &MIRBuilder,
288 SPIRVType *SpvType,
289 bool EmitIR) {
290 auto &MF = MIRBuilder.getMF();
291 const IntegerType *LLVMIntTy;
292 if (SpvType)
293 LLVMIntTy = cast<IntegerType>(Val: getTypeForSPIRVType(Ty: SpvType));
294 else
295 LLVMIntTy = IntegerType::getInt32Ty(C&: MF.getFunction().getContext());
296 // Find a constant in DT or build a new one.
297 const auto ConstInt =
298 ConstantInt::get(Ty: const_cast<IntegerType *>(LLVMIntTy), V: Val);
299 Register Res = DT.find(C: ConstInt, MF: &MF);
300 if (!Res.isValid()) {
301 unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(Type: SpvType) : 32;
302 LLT LLTy = LLT::scalar(SizeInBits: EmitIR ? BitWidth : 32);
303 Res = MF.getRegInfo().createGenericVirtualRegister(Ty: LLTy);
304 MF.getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
305 assignTypeToVReg(LLVMIntTy, Res, MIRBuilder,
306 SPIRV::AccessQualifier::ReadWrite, EmitIR);
307 DT.add(C: ConstInt, MF: &MIRBuilder.getMF(), R: Res);
308 if (EmitIR) {
309 MIRBuilder.buildConstant(Res, Val: *ConstInt);
310 } else {
311 MachineInstrBuilder MIB;
312 if (Val) {
313 assert(SpvType);
314 MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI)
315 .addDef(Res)
316 .addUse(getSPIRVTypeID(SpvType));
317 addNumImm(Imm: APInt(BitWidth, Val), MIB);
318 } else {
319 assert(SpvType);
320 MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
321 .addDef(Res)
322 .addUse(getSPIRVTypeID(SpvType));
323 }
324 const auto &Subtarget = CurMF->getSubtarget();
325 constrainSelectedInstRegOperands(I&: *MIB, TII: *Subtarget.getInstrInfo(),
326 TRI: *Subtarget.getRegisterInfo(),
327 RBI: *Subtarget.getRegBankInfo());
328 }
329 }
330 return Res;
331}
332
333Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val,
334 MachineIRBuilder &MIRBuilder,
335 SPIRVType *SpvType) {
336 auto &MF = MIRBuilder.getMF();
337 auto &Ctx = MF.getFunction().getContext();
338 if (!SpvType) {
339 const Type *LLVMFPTy = Type::getFloatTy(C&: Ctx);
340 SpvType = getOrCreateSPIRVType(LLVMFPTy, MIRBuilder);
341 }
342 // Find a constant in DT or build a new one.
343 const auto ConstFP = ConstantFP::get(Context&: Ctx, V: Val);
344 Register Res = DT.find(C: ConstFP, MF: &MF);
345 if (!Res.isValid()) {
346 Res = MF.getRegInfo().createGenericVirtualRegister(Ty: LLT::scalar(SizeInBits: 32));
347 MF.getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
348 assignSPIRVTypeToVReg(SpirvType: SpvType, VReg: Res, MF);
349 DT.add(C: ConstFP, MF: &MF, R: Res);
350
351 MachineInstrBuilder MIB;
352 MIB = MIRBuilder.buildInstr(SPIRV::OpConstantF)
353 .addDef(Res)
354 .addUse(getSPIRVTypeID(SpvType));
355 addNumImm(Imm: ConstFP->getValueAPF().bitcastToAPInt(), MIB);
356 }
357
358 return Res;
359}
360
361Register SPIRVGlobalRegistry::getOrCreateBaseRegister(Constant *Val,
362 MachineInstr &I,
363 SPIRVType *SpvType,
364 const SPIRVInstrInfo &TII,
365 unsigned BitWidth) {
366 SPIRVType *Type = SpvType;
367 if (SpvType->getOpcode() == SPIRV::OpTypeVector ||
368 SpvType->getOpcode() == SPIRV::OpTypeArray) {
369 auto EleTypeReg = SpvType->getOperand(i: 1).getReg();
370 Type = getSPIRVTypeForVReg(VReg: EleTypeReg);
371 }
372 if (Type->getOpcode() == SPIRV::OpTypeFloat) {
373 SPIRVType *SpvBaseType = getOrCreateSPIRVFloatType(BitWidth, I, TII);
374 return getOrCreateConstFP(Val: dyn_cast<ConstantFP>(Val)->getValue(), I,
375 SpvType: SpvBaseType, TII);
376 }
377 assert(Type->getOpcode() == SPIRV::OpTypeInt);
378 SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
379 return getOrCreateConstInt(Val: Val->getUniqueInteger().getSExtValue(), I,
380 SpvType: SpvBaseType, TII);
381}
382
383Register SPIRVGlobalRegistry::getOrCreateCompositeOrNull(
384 Constant *Val, MachineInstr &I, SPIRVType *SpvType,
385 const SPIRVInstrInfo &TII, Constant *CA, unsigned BitWidth,
386 unsigned ElemCnt, bool ZeroAsNull) {
387 // Find a constant vector in DT or build a new one.
388 Register Res = DT.find(C: CA, MF: CurMF);
389 // If no values are attached, the composite is null constant.
390 bool IsNull = Val->isNullValue() && ZeroAsNull;
391 if (!Res.isValid()) {
392 // SpvScalConst should be created before SpvVecConst to avoid undefined ID
393 // error on validation.
394 // TODO: can moved below once sorting of types/consts/defs is implemented.
395 Register SpvScalConst;
396 if (!IsNull)
397 SpvScalConst = getOrCreateBaseRegister(Val, I, SpvType, TII, BitWidth);
398
399 // TODO: handle cases where the type is not 32bit wide
400 // TODO: https://github.com/llvm/llvm-project/issues/88129
401 LLT LLTy = LLT::scalar(SizeInBits: 32);
402 Register SpvVecConst =
403 CurMF->getRegInfo().createGenericVirtualRegister(Ty: LLTy);
404 CurMF->getRegInfo().setRegClass(SpvVecConst, &SPIRV::IDRegClass);
405 assignSPIRVTypeToVReg(SpirvType: SpvType, VReg: SpvVecConst, MF&: *CurMF);
406 DT.add(C: CA, MF: CurMF, R: SpvVecConst);
407 MachineInstrBuilder MIB;
408 MachineBasicBlock &BB = *I.getParent();
409 if (!IsNull) {
410 MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantComposite))
411 .addDef(SpvVecConst)
412 .addUse(getSPIRVTypeID(SpvType));
413 for (unsigned i = 0; i < ElemCnt; ++i)
414 MIB.addUse(RegNo: SpvScalConst);
415 } else {
416 MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull))
417 .addDef(SpvVecConst)
418 .addUse(getSPIRVTypeID(SpvType));
419 }
420 const auto &Subtarget = CurMF->getSubtarget();
421 constrainSelectedInstRegOperands(I&: *MIB, TII: *Subtarget.getInstrInfo(),
422 TRI: *Subtarget.getRegisterInfo(),
423 RBI: *Subtarget.getRegBankInfo());
424 return SpvVecConst;
425 }
426 return Res;
427}
428
429Register SPIRVGlobalRegistry::getOrCreateConstVector(uint64_t Val,
430 MachineInstr &I,
431 SPIRVType *SpvType,
432 const SPIRVInstrInfo &TII,
433 bool ZeroAsNull) {
434 const Type *LLVMTy = getTypeForSPIRVType(Ty: SpvType);
435 assert(LLVMTy->isVectorTy());
436 const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(Val: LLVMTy);
437 Type *LLVMBaseTy = LLVMVecTy->getElementType();
438 assert(LLVMBaseTy->isIntegerTy());
439 auto *ConstVal = ConstantInt::get(Ty: LLVMBaseTy, V: Val);
440 auto *ConstVec =
441 ConstantVector::getSplat(EC: LLVMVecTy->getElementCount(), Elt: ConstVal);
442 unsigned BW = getScalarOrVectorBitWidth(Type: SpvType);
443 return getOrCreateCompositeOrNull(Val: ConstVal, I, SpvType, TII, CA: ConstVec, BitWidth: BW,
444 ElemCnt: SpvType->getOperand(i: 2).getImm(),
445 ZeroAsNull);
446}
447
448Register SPIRVGlobalRegistry::getOrCreateConstVector(APFloat Val,
449 MachineInstr &I,
450 SPIRVType *SpvType,
451 const SPIRVInstrInfo &TII,
452 bool ZeroAsNull) {
453 const Type *LLVMTy = getTypeForSPIRVType(Ty: SpvType);
454 assert(LLVMTy->isVectorTy());
455 const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(Val: LLVMTy);
456 Type *LLVMBaseTy = LLVMVecTy->getElementType();
457 assert(LLVMBaseTy->isFloatingPointTy());
458 auto *ConstVal = ConstantFP::get(Ty: LLVMBaseTy, V: Val);
459 auto *ConstVec =
460 ConstantVector::getSplat(EC: LLVMVecTy->getElementCount(), Elt: ConstVal);
461 unsigned BW = getScalarOrVectorBitWidth(Type: SpvType);
462 return getOrCreateCompositeOrNull(Val: ConstVal, I, SpvType, TII, CA: ConstVec, BitWidth: BW,
463 ElemCnt: SpvType->getOperand(i: 2).getImm(),
464 ZeroAsNull);
465}
466
467Register
468SPIRVGlobalRegistry::getOrCreateConsIntArray(uint64_t Val, MachineInstr &I,
469 SPIRVType *SpvType,
470 const SPIRVInstrInfo &TII) {
471 const Type *LLVMTy = getTypeForSPIRVType(Ty: SpvType);
472 assert(LLVMTy->isArrayTy());
473 const ArrayType *LLVMArrTy = cast<ArrayType>(Val: LLVMTy);
474 Type *LLVMBaseTy = LLVMArrTy->getElementType();
475 auto *ConstInt = ConstantInt::get(Ty: LLVMBaseTy, V: Val);
476 auto *ConstArr =
477 ConstantArray::get(T: const_cast<ArrayType *>(LLVMArrTy), V: {ConstInt});
478 SPIRVType *SpvBaseTy = getSPIRVTypeForVReg(VReg: SpvType->getOperand(i: 1).getReg());
479 unsigned BW = getScalarOrVectorBitWidth(Type: SpvBaseTy);
480 return getOrCreateCompositeOrNull(Val: ConstInt, I, SpvType, TII, CA: ConstArr, BitWidth: BW,
481 ElemCnt: LLVMArrTy->getNumElements());
482}
483
484Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
485 uint64_t Val, MachineIRBuilder &MIRBuilder, SPIRVType *SpvType, bool EmitIR,
486 Constant *CA, unsigned BitWidth, unsigned ElemCnt) {
487 Register Res = DT.find(C: CA, MF: CurMF);
488 if (!Res.isValid()) {
489 Register SpvScalConst;
490 if (Val || EmitIR) {
491 SPIRVType *SpvBaseType =
492 getOrCreateSPIRVIntegerType(BitWidth, MIRBuilder);
493 SpvScalConst = buildConstantInt(Val, MIRBuilder, SpvType: SpvBaseType, EmitIR);
494 }
495 LLT LLTy = EmitIR ? LLT::fixed_vector(NumElements: ElemCnt, ScalarSizeInBits: BitWidth) : LLT::scalar(SizeInBits: 32);
496 Register SpvVecConst =
497 CurMF->getRegInfo().createGenericVirtualRegister(Ty: LLTy);
498 CurMF->getRegInfo().setRegClass(SpvVecConst, &SPIRV::IDRegClass);
499 assignSPIRVTypeToVReg(SpirvType: SpvType, VReg: SpvVecConst, MF&: *CurMF);
500 DT.add(C: CA, MF: CurMF, R: SpvVecConst);
501 if (EmitIR) {
502 MIRBuilder.buildSplatVector(Res: SpvVecConst, Val: SpvScalConst);
503 } else {
504 if (Val) {
505 auto MIB = MIRBuilder.buildInstr(SPIRV::OpConstantComposite)
506 .addDef(SpvVecConst)
507 .addUse(getSPIRVTypeID(SpvType));
508 for (unsigned i = 0; i < ElemCnt; ++i)
509 MIB.addUse(SpvScalConst);
510 } else {
511 MIRBuilder.buildInstr(SPIRV::OpConstantNull)
512 .addDef(SpvVecConst)
513 .addUse(getSPIRVTypeID(SpvType));
514 }
515 }
516 return SpvVecConst;
517 }
518 return Res;
519}
520
521Register
522SPIRVGlobalRegistry::getOrCreateConsIntVector(uint64_t Val,
523 MachineIRBuilder &MIRBuilder,
524 SPIRVType *SpvType, bool EmitIR) {
525 const Type *LLVMTy = getTypeForSPIRVType(Ty: SpvType);
526 assert(LLVMTy->isVectorTy());
527 const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(Val: LLVMTy);
528 Type *LLVMBaseTy = LLVMVecTy->getElementType();
529 const auto ConstInt = ConstantInt::get(Ty: LLVMBaseTy, V: Val);
530 auto ConstVec =
531 ConstantVector::getSplat(EC: LLVMVecTy->getElementCount(), Elt: ConstInt);
532 unsigned BW = getScalarOrVectorBitWidth(Type: SpvType);
533 return getOrCreateIntCompositeOrNull(Val, MIRBuilder, SpvType, EmitIR,
534 CA: ConstVec, BitWidth: BW,
535 ElemCnt: SpvType->getOperand(i: 2).getImm());
536}
537
538Register
539SPIRVGlobalRegistry::getOrCreateConsIntArray(uint64_t Val,
540 MachineIRBuilder &MIRBuilder,
541 SPIRVType *SpvType, bool EmitIR) {
542 const Type *LLVMTy = getTypeForSPIRVType(Ty: SpvType);
543 assert(LLVMTy->isArrayTy());
544 const ArrayType *LLVMArrTy = cast<ArrayType>(Val: LLVMTy);
545 Type *LLVMBaseTy = LLVMArrTy->getElementType();
546 const auto ConstInt = ConstantInt::get(Ty: LLVMBaseTy, V: Val);
547 auto ConstArr =
548 ConstantArray::get(T: const_cast<ArrayType *>(LLVMArrTy), V: {ConstInt});
549 SPIRVType *SpvBaseTy = getSPIRVTypeForVReg(VReg: SpvType->getOperand(i: 1).getReg());
550 unsigned BW = getScalarOrVectorBitWidth(Type: SpvBaseTy);
551 return getOrCreateIntCompositeOrNull(Val, MIRBuilder, SpvType, EmitIR,
552 CA: ConstArr, BitWidth: BW,
553 ElemCnt: LLVMArrTy->getNumElements());
554}
555
556Register
557SPIRVGlobalRegistry::getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder,
558 SPIRVType *SpvType) {
559 const Type *LLVMTy = getTypeForSPIRVType(Ty: SpvType);
560 const TypedPointerType *LLVMPtrTy = cast<TypedPointerType>(Val: LLVMTy);
561 // Find a constant in DT or build a new one.
562 Constant *CP = ConstantPointerNull::get(T: PointerType::get(
563 ElementType: LLVMPtrTy->getElementType(), AddressSpace: LLVMPtrTy->getAddressSpace()));
564 Register Res = DT.find(C: CP, MF: CurMF);
565 if (!Res.isValid()) {
566 LLT LLTy = LLT::pointer(AddressSpace: LLVMPtrTy->getAddressSpace(), SizeInBits: PointerSize);
567 Res = CurMF->getRegInfo().createGenericVirtualRegister(Ty: LLTy);
568 CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
569 assignSPIRVTypeToVReg(SpirvType: SpvType, VReg: Res, MF&: *CurMF);
570 MIRBuilder.buildInstr(SPIRV::OpConstantNull)
571 .addDef(Res)
572 .addUse(getSPIRVTypeID(SpvType));
573 DT.add(C: CP, MF: CurMF, R: Res);
574 }
575 return Res;
576}
577
578Register SPIRVGlobalRegistry::buildConstantSampler(
579 Register ResReg, unsigned AddrMode, unsigned Param, unsigned FilerMode,
580 MachineIRBuilder &MIRBuilder, SPIRVType *SpvType) {
581 SPIRVType *SampTy;
582 if (SpvType)
583 SampTy = getOrCreateSPIRVType(getTypeForSPIRVType(Ty: SpvType), MIRBuilder);
584 else if ((SampTy = getOrCreateSPIRVTypeByName("opencl.sampler_t",
585 MIRBuilder)) == nullptr)
586 report_fatal_error(reason: "Unable to recognize SPIRV type name: opencl.sampler_t");
587
588 auto Sampler =
589 ResReg.isValid()
590 ? ResReg
591 : MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
592 auto Res = MIRBuilder.buildInstr(SPIRV::OpConstantSampler)
593 .addDef(Sampler)
594 .addUse(getSPIRVTypeID(SampTy))
595 .addImm(AddrMode)
596 .addImm(Param)
597 .addImm(FilerMode);
598 assert(Res->getOperand(0).isReg());
599 return Res->getOperand(0).getReg();
600}
601
602Register SPIRVGlobalRegistry::buildGlobalVariable(
603 Register ResVReg, SPIRVType *BaseType, StringRef Name,
604 const GlobalValue *GV, SPIRV::StorageClass::StorageClass Storage,
605 const MachineInstr *Init, bool IsConst, bool HasLinkageTy,
606 SPIRV::LinkageType::LinkageType LinkageType, MachineIRBuilder &MIRBuilder,
607 bool IsInstSelector) {
608 const GlobalVariable *GVar = nullptr;
609 if (GV)
610 GVar = cast<const GlobalVariable>(Val: GV);
611 else {
612 // If GV is not passed explicitly, use the name to find or construct
613 // the global variable.
614 Module *M = MIRBuilder.getMF().getFunction().getParent();
615 GVar = M->getGlobalVariable(Name);
616 if (GVar == nullptr) {
617 const Type *Ty = getTypeForSPIRVType(Ty: BaseType); // TODO: check type.
618 // Module takes ownership of the global var.
619 GVar = new GlobalVariable(*M, const_cast<Type *>(Ty), false,
620 GlobalValue::ExternalLinkage, nullptr,
621 Twine(Name));
622 }
623 GV = GVar;
624 }
625 Register Reg = DT.find(GV: GVar, MF: &MIRBuilder.getMF());
626 if (Reg.isValid()) {
627 if (Reg != ResVReg)
628 MIRBuilder.buildCopy(Res: ResVReg, Op: Reg);
629 return ResVReg;
630 }
631
632 auto MIB = MIRBuilder.buildInstr(SPIRV::OpVariable)
633 .addDef(ResVReg)
634 .addUse(getSPIRVTypeID(BaseType))
635 .addImm(static_cast<uint32_t>(Storage));
636
637 if (Init != 0) {
638 MIB.addUse(Init->getOperand(i: 0).getReg());
639 }
640
641 // ISel may introduce a new register on this step, so we need to add it to
642 // DT and correct its type avoiding fails on the next stage.
643 if (IsInstSelector) {
644 const auto &Subtarget = CurMF->getSubtarget();
645 constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
646 *Subtarget.getRegisterInfo(),
647 *Subtarget.getRegBankInfo());
648 }
649 Reg = MIB->getOperand(0).getReg();
650 DT.add(GV: GVar, MF: &MIRBuilder.getMF(), R: Reg);
651
652 // Set to Reg the same type as ResVReg has.
653 auto MRI = MIRBuilder.getMRI();
654 assert(MRI->getType(ResVReg).isPointer() && "Pointer type is expected");
655 if (Reg != ResVReg) {
656 LLT RegLLTy =
657 LLT::pointer(AddressSpace: MRI->getType(Reg: ResVReg).getAddressSpace(), SizeInBits: getPointerSize());
658 MRI->setType(VReg: Reg, Ty: RegLLTy);
659 assignSPIRVTypeToVReg(SpirvType: BaseType, VReg: Reg, MF&: MIRBuilder.getMF());
660 } else {
661 // Our knowledge about the type may be updated.
662 // If that's the case, we need to update a type
663 // associated with the register.
664 SPIRVType *DefType = getSPIRVTypeForVReg(VReg: ResVReg);
665 if (!DefType || DefType != BaseType)
666 assignSPIRVTypeToVReg(SpirvType: BaseType, VReg: Reg, MF&: MIRBuilder.getMF());
667 }
668
669 // If it's a global variable with name, output OpName for it.
670 if (GVar && GVar->hasName())
671 buildOpName(Target: Reg, Name: GVar->getName(), MIRBuilder);
672
673 // Output decorations for the GV.
674 // TODO: maybe move to GenerateDecorations pass.
675 const SPIRVSubtarget &ST =
676 cast<SPIRVSubtarget>(Val: MIRBuilder.getMF().getSubtarget());
677 if (IsConst && ST.isOpenCLEnv())
678 buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Constant, {});
679
680 if (GVar && GVar->getAlign().valueOrOne().value() != 1) {
681 unsigned Alignment = (unsigned)GVar->getAlign().valueOrOne().value();
682 buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Alignment, {Alignment});
683 }
684
685 if (HasLinkageTy)
686 buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::LinkageAttributes,
687 {static_cast<uint32_t>(LinkageType)}, Name);
688
689 SPIRV::BuiltIn::BuiltIn BuiltInId;
690 if (getSpirvBuiltInIdByName(Name, BuiltInId))
691 buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::BuiltIn,
692 {static_cast<uint32_t>(BuiltInId)});
693
694 return Reg;
695}
696
697SPIRVType *SPIRVGlobalRegistry::getOpTypeArray(uint32_t NumElems,
698 SPIRVType *ElemType,
699 MachineIRBuilder &MIRBuilder,
700 bool EmitIR) {
701 assert((ElemType->getOpcode() != SPIRV::OpTypeVoid) &&
702 "Invalid array element type");
703 Register NumElementsVReg =
704 buildConstantInt(Val: NumElems, MIRBuilder, SpvType: nullptr, EmitIR);
705 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeArray)
706 .addDef(createTypeVReg(MIRBuilder))
707 .addUse(getSPIRVTypeID(ElemType))
708 .addUse(NumElementsVReg);
709 return MIB;
710}
711
712SPIRVType *SPIRVGlobalRegistry::getOpTypeOpaque(const StructType *Ty,
713 MachineIRBuilder &MIRBuilder) {
714 assert(Ty->hasName());
715 const StringRef Name = Ty->hasName() ? Ty->getName() : "";
716 Register ResVReg = createTypeVReg(MIRBuilder);
717 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeOpaque).addDef(ResVReg);
718 addStringImm(Name, MIB);
719 buildOpName(Target: ResVReg, Name, MIRBuilder);
720 return MIB;
721}
722
723SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(const StructType *Ty,
724 MachineIRBuilder &MIRBuilder,
725 bool EmitIR) {
726 SmallVector<Register, 4> FieldTypes;
727 for (const auto &Elem : Ty->elements()) {
728 SPIRVType *ElemTy =
729 findSPIRVType(toTypedPointer(Elem, Ty->getContext()), MIRBuilder);
730 assert(ElemTy && ElemTy->getOpcode() != SPIRV::OpTypeVoid &&
731 "Invalid struct element type");
732 FieldTypes.push_back(Elt: getSPIRVTypeID(SpirvType: ElemTy));
733 }
734 Register ResVReg = createTypeVReg(MIRBuilder);
735 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeStruct).addDef(ResVReg);
736 for (const auto &Ty : FieldTypes)
737 MIB.addUse(Ty);
738 if (Ty->hasName())
739 buildOpName(Target: ResVReg, Name: Ty->getName(), MIRBuilder);
740 if (Ty->isPacked())
741 buildOpDecorate(ResVReg, MIRBuilder, SPIRV::Decoration::CPacked, {});
742 return MIB;
743}
744
745SPIRVType *SPIRVGlobalRegistry::getOrCreateSpecialType(
746 const Type *Ty, MachineIRBuilder &MIRBuilder,
747 SPIRV::AccessQualifier::AccessQualifier AccQual) {
748 assert(isSpecialOpaqueType(Ty) && "Not a special opaque builtin type");
749 return SPIRV::lowerBuiltinType(Ty, AccQual, MIRBuilder, this);
750}
751
752SPIRVType *SPIRVGlobalRegistry::getOpTypePointer(
753 SPIRV::StorageClass::StorageClass SC, SPIRVType *ElemType,
754 MachineIRBuilder &MIRBuilder, Register Reg) {
755 if (!Reg.isValid())
756 Reg = createTypeVReg(MIRBuilder);
757 return MIRBuilder.buildInstr(SPIRV::OpTypePointer)
758 .addDef(Reg)
759 .addImm(static_cast<uint32_t>(SC))
760 .addUse(getSPIRVTypeID(ElemType));
761}
762
763SPIRVType *SPIRVGlobalRegistry::getOpTypeForwardPointer(
764 SPIRV::StorageClass::StorageClass SC, MachineIRBuilder &MIRBuilder) {
765 return MIRBuilder.buildInstr(SPIRV::OpTypeForwardPointer)
766 .addUse(createTypeVReg(MIRBuilder))
767 .addImm(static_cast<uint32_t>(SC));
768}
769
770SPIRVType *SPIRVGlobalRegistry::getOpTypeFunction(
771 SPIRVType *RetType, const SmallVectorImpl<SPIRVType *> &ArgTypes,
772 MachineIRBuilder &MIRBuilder) {
773 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFunction)
774 .addDef(createTypeVReg(MIRBuilder))
775 .addUse(getSPIRVTypeID(RetType));
776 for (const SPIRVType *ArgType : ArgTypes)
777 MIB.addUse(getSPIRVTypeID(SpirvType: ArgType));
778 return MIB;
779}
780
781SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeFunctionWithArgs(
782 const Type *Ty, SPIRVType *RetType,
783 const SmallVectorImpl<SPIRVType *> &ArgTypes,
784 MachineIRBuilder &MIRBuilder) {
785 Register Reg = DT.find(Ty, MF: &MIRBuilder.getMF());
786 if (Reg.isValid())
787 return getSPIRVTypeForVReg(VReg: Reg);
788 SPIRVType *SpirvType = getOpTypeFunction(RetType, ArgTypes, MIRBuilder);
789 DT.add(Ty, MF: CurMF, R: getSPIRVTypeID(SpirvType));
790 return finishCreatingSPIRVType(LLVMTy: Ty, SpirvType);
791}
792
793SPIRVType *SPIRVGlobalRegistry::findSPIRVType(
794 const Type *Ty, MachineIRBuilder &MIRBuilder,
795 SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
796 Register Reg = DT.find(Ty, MF: &MIRBuilder.getMF());
797 if (Reg.isValid())
798 return getSPIRVTypeForVReg(VReg: Reg);
799 if (ForwardPointerTypes.contains(Val: Ty))
800 return ForwardPointerTypes[Ty];
801 return restOfCreateSPIRVType(Ty, MIRBuilder, AccQual, EmitIR);
802}
803
804Register SPIRVGlobalRegistry::getSPIRVTypeID(const SPIRVType *SpirvType) const {
805 assert(SpirvType && "Attempting to get type id for nullptr type.");
806 if (SpirvType->getOpcode() == SPIRV::OpTypeForwardPointer)
807 return SpirvType->uses().begin()->getReg();
808 return SpirvType->defs().begin()->getReg();
809}
810
811SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
812 const Type *Ty, MachineIRBuilder &MIRBuilder,
813 SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
814 if (isSpecialOpaqueType(Ty))
815 return getOrCreateSpecialType(Ty, MIRBuilder, AccQual);
816 auto &TypeToSPIRVTypeMap = DT.getTypes()->getAllUses();
817 auto t = TypeToSPIRVTypeMap.find(Key: Ty);
818 if (t != TypeToSPIRVTypeMap.end()) {
819 auto tt = t->second.find(Key: &MIRBuilder.getMF());
820 if (tt != t->second.end())
821 return getSPIRVTypeForVReg(VReg: tt->second);
822 }
823
824 if (auto IType = dyn_cast<IntegerType>(Val: Ty)) {
825 const unsigned Width = IType->getBitWidth();
826 return Width == 1 ? getOpTypeBool(MIRBuilder)
827 : getOpTypeInt(Width, MIRBuilder, IsSigned: false);
828 }
829 if (Ty->isFloatingPointTy())
830 return getOpTypeFloat(Width: Ty->getPrimitiveSizeInBits(), MIRBuilder);
831 if (Ty->isVoidTy())
832 return getOpTypeVoid(MIRBuilder);
833 if (Ty->isVectorTy()) {
834 SPIRVType *El =
835 findSPIRVType(cast<FixedVectorType>(Ty)->getElementType(), MIRBuilder);
836 return getOpTypeVector(NumElems: cast<FixedVectorType>(Val: Ty)->getNumElements(), ElemType: El,
837 MIRBuilder);
838 }
839 if (Ty->isArrayTy()) {
840 SPIRVType *El = findSPIRVType(Ty->getArrayElementType(), MIRBuilder);
841 return getOpTypeArray(NumElems: Ty->getArrayNumElements(), ElemType: El, MIRBuilder, EmitIR);
842 }
843 if (auto SType = dyn_cast<StructType>(Val: Ty)) {
844 if (SType->isOpaque())
845 return getOpTypeOpaque(Ty: SType, MIRBuilder);
846 return getOpTypeStruct(Ty: SType, MIRBuilder, EmitIR);
847 }
848 if (auto FType = dyn_cast<FunctionType>(Val: Ty)) {
849 SPIRVType *RetTy = findSPIRVType(FType->getReturnType(), MIRBuilder);
850 SmallVector<SPIRVType *, 4> ParamTypes;
851 for (const auto &t : FType->params()) {
852 ParamTypes.push_back(findSPIRVType(t, MIRBuilder));
853 }
854 return getOpTypeFunction(RetType: RetTy, ArgTypes: ParamTypes, MIRBuilder);
855 }
856 unsigned AddrSpace = 0xFFFF;
857 if (auto PType = dyn_cast<TypedPointerType>(Val: Ty))
858 AddrSpace = PType->getAddressSpace();
859 else if (auto PType = dyn_cast<PointerType>(Val: Ty))
860 AddrSpace = PType->getAddressSpace();
861 else
862 report_fatal_error(reason: "Unable to convert LLVM type to SPIRVType", gen_crash_diag: true);
863
864 SPIRVType *SpvElementType = nullptr;
865 if (auto PType = dyn_cast<TypedPointerType>(Val: Ty))
866 SpvElementType = getOrCreateSPIRVType(BitWidth: PType->getElementType(), I&: MIRBuilder,
867 TII: AccQual, SPIRVOPcode: EmitIR);
868 else
869 SpvElementType = getOrCreateSPIRVIntegerType(BitWidth: 8, MIRBuilder);
870
871 // Get access to information about available extensions
872 const SPIRVSubtarget *ST =
873 static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
874 auto SC = addressSpaceToStorageClass(AddrSpace, STI: *ST);
875 // Null pointer means we have a loop in type definitions, make and
876 // return corresponding OpTypeForwardPointer.
877 if (SpvElementType == nullptr) {
878 if (!ForwardPointerTypes.contains(Val: Ty))
879 ForwardPointerTypes[Ty] = getOpTypeForwardPointer(SC, MIRBuilder);
880 return ForwardPointerTypes[Ty];
881 }
882 // If we have forward pointer associated with this type, use its register
883 // operand to create OpTypePointer.
884 if (ForwardPointerTypes.contains(Val: Ty)) {
885 Register Reg = getSPIRVTypeID(SpirvType: ForwardPointerTypes[Ty]);
886 return getOpTypePointer(SC, ElemType: SpvElementType, MIRBuilder, Reg);
887 }
888
889 return getOrCreateSPIRVPointerType(BaseType: SpvElementType, MIRBuilder, SClass: SC);
890}
891
892SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
893 const Type *Ty, MachineIRBuilder &MIRBuilder,
894 SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
895 if (TypesInProcessing.count(Ptr: Ty) && !isPointerTy(T: Ty))
896 return nullptr;
897 TypesInProcessing.insert(Ptr: Ty);
898 SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR);
899 TypesInProcessing.erase(Ptr: Ty);
900 VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType;
901 SPIRVToLLVMType[SpirvType] = Ty;
902 Register Reg = DT.find(Ty, MF: &MIRBuilder.getMF());
903 // Do not add OpTypeForwardPointer to DT, a corresponding normal pointer type
904 // will be added later. For special types it is already added to DT.
905 if (SpirvType->getOpcode() != SPIRV::OpTypeForwardPointer && !Reg.isValid() &&
906 !isSpecialOpaqueType(Ty)) {
907 if (!isPointerTy(T: Ty))
908 DT.add(Ty, MF: &MIRBuilder.getMF(), R: getSPIRVTypeID(SpirvType));
909 else if (isTypedPointerTy(T: Ty))
910 DT.add(PointerElementType: cast<TypedPointerType>(Val: Ty)->getElementType(),
911 AddressSpace: getPointerAddressSpace(T: Ty), MF: &MIRBuilder.getMF(),
912 R: getSPIRVTypeID(SpirvType));
913 else
914 DT.add(PointerElementType: Type::getInt8Ty(C&: MIRBuilder.getMF().getFunction().getContext()),
915 AddressSpace: getPointerAddressSpace(T: Ty), MF: &MIRBuilder.getMF(),
916 R: getSPIRVTypeID(SpirvType));
917 }
918
919 return SpirvType;
920}
921
922SPIRVType *
923SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg,
924 const MachineFunction *MF) const {
925 auto t = VRegToTypeMap.find(Val: MF ? MF : CurMF);
926 if (t != VRegToTypeMap.end()) {
927 auto tt = t->second.find(Val: VReg);
928 if (tt != t->second.end())
929 return tt->second;
930 }
931 return nullptr;
932}
933
934SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(
935 const Type *Ty, MachineIRBuilder &MIRBuilder,
936 SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
937 Register Reg;
938 if (!isPointerTy(T: Ty))
939 Reg = DT.find(Ty, MF: &MIRBuilder.getMF());
940 else if (isTypedPointerTy(T: Ty))
941 Reg = DT.find(PointerElementType: cast<TypedPointerType>(Val: Ty)->getElementType(),
942 AddressSpace: getPointerAddressSpace(T: Ty), MF: &MIRBuilder.getMF());
943 else
944 Reg =
945 DT.find(PointerElementType: Type::getInt8Ty(C&: MIRBuilder.getMF().getFunction().getContext()),
946 AddressSpace: getPointerAddressSpace(T: Ty), MF: &MIRBuilder.getMF());
947
948 if (Reg.isValid() && !isSpecialOpaqueType(Ty))
949 return getSPIRVTypeForVReg(VReg: Reg);
950 TypesInProcessing.clear();
951 SPIRVType *STy = restOfCreateSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR);
952 // Create normal pointer types for the corresponding OpTypeForwardPointers.
953 for (auto &CU : ForwardPointerTypes) {
954 const Type *Ty2 = CU.first;
955 SPIRVType *STy2 = CU.second;
956 if ((Reg = DT.find(Ty: Ty2, MF: &MIRBuilder.getMF())).isValid())
957 STy2 = getSPIRVTypeForVReg(VReg: Reg);
958 else
959 STy2 = restOfCreateSPIRVType(Ty2, MIRBuilder, AccessQual, EmitIR);
960 if (Ty == Ty2)
961 STy = STy2;
962 }
963 ForwardPointerTypes.clear();
964 return STy;
965}
966
967bool SPIRVGlobalRegistry::isScalarOfType(Register VReg,
968 unsigned TypeOpcode) const {
969 SPIRVType *Type = getSPIRVTypeForVReg(VReg);
970 assert(Type && "isScalarOfType VReg has no type assigned");
971 return Type->getOpcode() == TypeOpcode;
972}
973
974bool SPIRVGlobalRegistry::isScalarOrVectorOfType(Register VReg,
975 unsigned TypeOpcode) const {
976 SPIRVType *Type = getSPIRVTypeForVReg(VReg);
977 assert(Type && "isScalarOrVectorOfType VReg has no type assigned");
978 if (Type->getOpcode() == TypeOpcode)
979 return true;
980 if (Type->getOpcode() == SPIRV::OpTypeVector) {
981 Register ScalarTypeVReg = Type->getOperand(i: 1).getReg();
982 SPIRVType *ScalarType = getSPIRVTypeForVReg(VReg: ScalarTypeVReg);
983 return ScalarType->getOpcode() == TypeOpcode;
984 }
985 return false;
986}
987
988unsigned
989SPIRVGlobalRegistry::getScalarOrVectorComponentCount(Register VReg) const {
990 return getScalarOrVectorComponentCount(Type: getSPIRVTypeForVReg(VReg));
991}
992
993unsigned
994SPIRVGlobalRegistry::getScalarOrVectorComponentCount(SPIRVType *Type) const {
995 if (!Type)
996 return 0;
997 return Type->getOpcode() == SPIRV::OpTypeVector
998 ? static_cast<unsigned>(Type->getOperand(2).getImm())
999 : 1;
1000}
1001
1002unsigned
1003SPIRVGlobalRegistry::getScalarOrVectorBitWidth(const SPIRVType *Type) const {
1004 assert(Type && "Invalid Type pointer");
1005 if (Type->getOpcode() == SPIRV::OpTypeVector) {
1006 auto EleTypeReg = Type->getOperand(i: 1).getReg();
1007 Type = getSPIRVTypeForVReg(VReg: EleTypeReg);
1008 }
1009 if (Type->getOpcode() == SPIRV::OpTypeInt ||
1010 Type->getOpcode() == SPIRV::OpTypeFloat)
1011 return Type->getOperand(i: 1).getImm();
1012 if (Type->getOpcode() == SPIRV::OpTypeBool)
1013 return 1;
1014 llvm_unreachable("Attempting to get bit width of non-integer/float type.");
1015}
1016
1017unsigned SPIRVGlobalRegistry::getNumScalarOrVectorTotalBitWidth(
1018 const SPIRVType *Type) const {
1019 assert(Type && "Invalid Type pointer");
1020 unsigned NumElements = 1;
1021 if (Type->getOpcode() == SPIRV::OpTypeVector) {
1022 NumElements = static_cast<unsigned>(Type->getOperand(i: 2).getImm());
1023 Type = getSPIRVTypeForVReg(VReg: Type->getOperand(i: 1).getReg());
1024 }
1025 return Type->getOpcode() == SPIRV::OpTypeInt ||
1026 Type->getOpcode() == SPIRV::OpTypeFloat
1027 ? NumElements * Type->getOperand(1).getImm()
1028 : 0;
1029}
1030
1031const SPIRVType *SPIRVGlobalRegistry::retrieveScalarOrVectorIntType(
1032 const SPIRVType *Type) const {
1033 if (Type && Type->getOpcode() == SPIRV::OpTypeVector)
1034 Type = getSPIRVTypeForVReg(VReg: Type->getOperand(i: 1).getReg());
1035 return Type && Type->getOpcode() == SPIRV::OpTypeInt ? Type : nullptr;
1036}
1037
1038bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const {
1039 const SPIRVType *IntType = retrieveScalarOrVectorIntType(Type);
1040 return IntType && IntType->getOperand(i: 2).getImm() != 0;
1041}
1042
1043unsigned SPIRVGlobalRegistry::getPointeeTypeOp(Register PtrReg) {
1044 SPIRVType *PtrType = getSPIRVTypeForVReg(VReg: PtrReg);
1045 SPIRVType *ElemType =
1046 PtrType && PtrType->getOpcode() == SPIRV::OpTypePointer
1047 ? getSPIRVTypeForVReg(PtrType->getOperand(2).getReg())
1048 : nullptr;
1049 return ElemType ? ElemType->getOpcode() : 0;
1050}
1051
1052bool SPIRVGlobalRegistry::isBitcastCompatible(const SPIRVType *Type1,
1053 const SPIRVType *Type2) const {
1054 if (!Type1 || !Type2)
1055 return false;
1056 auto Op1 = Type1->getOpcode(), Op2 = Type2->getOpcode();
1057 // Ignore difference between <1.5 and >=1.5 protocol versions:
1058 // it's valid if either Result Type or Operand is a pointer, and the other
1059 // is a pointer, an integer scalar, or an integer vector.
1060 if (Op1 == SPIRV::OpTypePointer &&
1061 (Op2 == SPIRV::OpTypePointer || retrieveScalarOrVectorIntType(Type2)))
1062 return true;
1063 if (Op2 == SPIRV::OpTypePointer &&
1064 (Op1 == SPIRV::OpTypePointer || retrieveScalarOrVectorIntType(Type1)))
1065 return true;
1066 unsigned Bits1 = getNumScalarOrVectorTotalBitWidth(Type: Type1),
1067 Bits2 = getNumScalarOrVectorTotalBitWidth(Type: Type2);
1068 return Bits1 > 0 && Bits1 == Bits2;
1069}
1070
1071SPIRV::StorageClass::StorageClass
1072SPIRVGlobalRegistry::getPointerStorageClass(Register VReg) const {
1073 SPIRVType *Type = getSPIRVTypeForVReg(VReg);
1074 assert(Type && Type->getOpcode() == SPIRV::OpTypePointer &&
1075 Type->getOperand(1).isImm() && "Pointer type is expected");
1076 return static_cast<SPIRV::StorageClass::StorageClass>(
1077 Type->getOperand(i: 1).getImm());
1078}
1079
1080SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeImage(
1081 MachineIRBuilder &MIRBuilder, SPIRVType *SampledType, SPIRV::Dim::Dim Dim,
1082 uint32_t Depth, uint32_t Arrayed, uint32_t Multisampled, uint32_t Sampled,
1083 SPIRV::ImageFormat::ImageFormat ImageFormat,
1084 SPIRV::AccessQualifier::AccessQualifier AccessQual) {
1085 SPIRV::ImageTypeDescriptor TD(SPIRVToLLVMType.lookup(Val: SampledType), Dim, Depth,
1086 Arrayed, Multisampled, Sampled, ImageFormat,
1087 AccessQual);
1088 if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
1089 return Res;
1090 Register ResVReg = createTypeVReg(MIRBuilder);
1091 DT.add(TD, MF: &MIRBuilder.getMF(), R: ResVReg);
1092 return MIRBuilder.buildInstr(SPIRV::OpTypeImage)
1093 .addDef(ResVReg)
1094 .addUse(getSPIRVTypeID(SampledType))
1095 .addImm(Dim)
1096 .addImm(Depth) // Depth (whether or not it is a Depth image).
1097 .addImm(Arrayed) // Arrayed.
1098 .addImm(Multisampled) // Multisampled (0 = only single-sample).
1099 .addImm(Sampled) // Sampled (0 = usage known at runtime).
1100 .addImm(ImageFormat)
1101 .addImm(AccessQual);
1102}
1103
1104SPIRVType *
1105SPIRVGlobalRegistry::getOrCreateOpTypeSampler(MachineIRBuilder &MIRBuilder) {
1106 SPIRV::SamplerTypeDescriptor TD;
1107 if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
1108 return Res;
1109 Register ResVReg = createTypeVReg(MIRBuilder);
1110 DT.add(TD, MF: &MIRBuilder.getMF(), R: ResVReg);
1111 return MIRBuilder.buildInstr(SPIRV::OpTypeSampler).addDef(ResVReg);
1112}
1113
1114SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypePipe(
1115 MachineIRBuilder &MIRBuilder,
1116 SPIRV::AccessQualifier::AccessQualifier AccessQual) {
1117 SPIRV::PipeTypeDescriptor TD(AccessQual);
1118 if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
1119 return Res;
1120 Register ResVReg = createTypeVReg(MIRBuilder);
1121 DT.add(TD, MF: &MIRBuilder.getMF(), R: ResVReg);
1122 return MIRBuilder.buildInstr(SPIRV::OpTypePipe)
1123 .addDef(ResVReg)
1124 .addImm(AccessQual);
1125}
1126
1127SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeDeviceEvent(
1128 MachineIRBuilder &MIRBuilder) {
1129 SPIRV::DeviceEventTypeDescriptor TD;
1130 if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
1131 return Res;
1132 Register ResVReg = createTypeVReg(MIRBuilder);
1133 DT.add(TD, MF: &MIRBuilder.getMF(), R: ResVReg);
1134 return MIRBuilder.buildInstr(SPIRV::OpTypeDeviceEvent).addDef(ResVReg);
1135}
1136
1137SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeSampledImage(
1138 SPIRVType *ImageType, MachineIRBuilder &MIRBuilder) {
1139 SPIRV::SampledImageTypeDescriptor TD(
1140 SPIRVToLLVMType.lookup(Val: MIRBuilder.getMF().getRegInfo().getVRegDef(
1141 Reg: ImageType->getOperand(i: 1).getReg())),
1142 ImageType);
1143 if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
1144 return Res;
1145 Register ResVReg = createTypeVReg(MIRBuilder);
1146 DT.add(TD, MF: &MIRBuilder.getMF(), R: ResVReg);
1147 return MIRBuilder.buildInstr(SPIRV::OpTypeSampledImage)
1148 .addDef(ResVReg)
1149 .addUse(getSPIRVTypeID(ImageType));
1150}
1151
1152SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeByOpcode(
1153 const Type *Ty, MachineIRBuilder &MIRBuilder, unsigned Opcode) {
1154 Register ResVReg = DT.find(Ty, MF: &MIRBuilder.getMF());
1155 if (ResVReg.isValid())
1156 return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(Reg: ResVReg);
1157 ResVReg = createTypeVReg(MIRBuilder);
1158 SPIRVType *SpirvTy = MIRBuilder.buildInstr(Opcode).addDef(RegNo: ResVReg);
1159 DT.add(Ty, MF: &MIRBuilder.getMF(), R: ResVReg);
1160 return SpirvTy;
1161}
1162
1163const MachineInstr *
1164SPIRVGlobalRegistry::checkSpecialInstr(const SPIRV::SpecialTypeDescriptor &TD,
1165 MachineIRBuilder &MIRBuilder) {
1166 Register Reg = DT.find(TD, MF: &MIRBuilder.getMF());
1167 if (Reg.isValid())
1168 return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(Reg);
1169 return nullptr;
1170}
1171
1172// Returns nullptr if unable to recognize SPIRV type name
1173SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName(
1174 StringRef TypeStr, MachineIRBuilder &MIRBuilder,
1175 SPIRV::StorageClass::StorageClass SC,
1176 SPIRV::AccessQualifier::AccessQualifier AQ) {
1177 unsigned VecElts = 0;
1178 auto &Ctx = MIRBuilder.getMF().getFunction().getContext();
1179
1180 // Parse strings representing either a SPIR-V or OpenCL builtin type.
1181 if (hasBuiltinTypePrefix(Name: TypeStr))
1182 return getOrCreateSPIRVType(BitWidth: SPIRV::parseBuiltinTypeNameToTargetExtType(
1183 TypeName: TypeStr.str(), Context&: MIRBuilder.getContext()),
1184 I&: MIRBuilder, TII: AQ);
1185
1186 // Parse type name in either "typeN" or "type vector[N]" format, where
1187 // N is the number of elements of the vector.
1188 Type *Ty;
1189
1190 Ty = parseBasicTypeName(TypeName&: TypeStr, Ctx);
1191 if (!Ty)
1192 // Unable to recognize SPIRV type name
1193 return nullptr;
1194
1195 auto SpirvTy = getOrCreateSPIRVType(BitWidth: Ty, I&: MIRBuilder, TII: AQ);
1196
1197 // Handle "type*" or "type* vector[N]".
1198 if (TypeStr.starts_with(Prefix: "*")) {
1199 SpirvTy = getOrCreateSPIRVPointerType(SpirvTy, MIRBuilder, SC);
1200 TypeStr = TypeStr.substr(Start: strlen(s: "*"));
1201 }
1202
1203 // Handle "typeN*" or "type vector[N]*".
1204 bool IsPtrToVec = TypeStr.consume_back(Suffix: "*");
1205
1206 if (TypeStr.consume_front(Prefix: " vector[")) {
1207 TypeStr = TypeStr.substr(Start: 0, N: TypeStr.find(C: ']'));
1208 }
1209 TypeStr.getAsInteger(Radix: 10, Result&: VecElts);
1210 if (VecElts > 0)
1211 SpirvTy = getOrCreateSPIRVVectorType(SpirvTy, VecElts, MIRBuilder);
1212
1213 if (IsPtrToVec)
1214 SpirvTy = getOrCreateSPIRVPointerType(SpirvTy, MIRBuilder, SC);
1215
1216 return SpirvTy;
1217}
1218
1219SPIRVType *
1220SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(unsigned BitWidth,
1221 MachineIRBuilder &MIRBuilder) {
1222 return getOrCreateSPIRVType(
1223 IntegerType::get(C&: MIRBuilder.getMF().getFunction().getContext(), NumBits: BitWidth),
1224 MIRBuilder);
1225}
1226
1227SPIRVType *SPIRVGlobalRegistry::finishCreatingSPIRVType(const Type *LLVMTy,
1228 SPIRVType *SpirvType) {
1229 assert(CurMF == SpirvType->getMF());
1230 VRegToTypeMap[CurMF][getSPIRVTypeID(SpirvType)] = SpirvType;
1231 SPIRVToLLVMType[SpirvType] = LLVMTy;
1232 return SpirvType;
1233}
1234
1235SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth,
1236 MachineInstr &I,
1237 const SPIRVInstrInfo &TII,
1238 unsigned SPIRVOPcode,
1239 Type *LLVMTy) {
1240 Register Reg = DT.find(Ty: LLVMTy, MF: CurMF);
1241 if (Reg.isValid())
1242 return getSPIRVTypeForVReg(VReg: Reg);
1243 MachineBasicBlock &BB = *I.getParent();
1244 auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRVOPcode))
1245 .addDef(createTypeVReg(MRI&: CurMF->getRegInfo()))
1246 .addImm(BitWidth)
1247 .addImm(0);
1248 DT.add(LLVMTy, CurMF, getSPIRVTypeID(SpirvType: MIB));
1249 return finishCreatingSPIRVType(LLVMTy, SpirvType: MIB);
1250}
1251
1252SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(
1253 unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
1254 Type *LLVMTy = IntegerType::get(C&: CurMF->getFunction().getContext(), NumBits: BitWidth);
1255 return getOrCreateSPIRVType(BitWidth, I, TII, SPIRV::OpTypeInt, LLVMTy);
1256}
1257SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVFloatType(
1258 unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
1259 LLVMContext &Ctx = CurMF->getFunction().getContext();
1260 Type *LLVMTy;
1261 switch (BitWidth) {
1262 case 16:
1263 LLVMTy = Type::getHalfTy(C&: Ctx);
1264 break;
1265 case 32:
1266 LLVMTy = Type::getFloatTy(C&: Ctx);
1267 break;
1268 case 64:
1269 LLVMTy = Type::getDoubleTy(C&: Ctx);
1270 break;
1271 default:
1272 llvm_unreachable("Bit width is of unexpected size.");
1273 }
1274 return getOrCreateSPIRVType(BitWidth, I, TII, SPIRV::OpTypeFloat, LLVMTy);
1275}
1276
1277SPIRVType *
1278SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder) {
1279 return getOrCreateSPIRVType(
1280 IntegerType::get(C&: MIRBuilder.getMF().getFunction().getContext(), NumBits: 1),
1281 MIRBuilder);
1282}
1283
1284SPIRVType *
1285SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineInstr &I,
1286 const SPIRVInstrInfo &TII) {
1287 Type *LLVMTy = IntegerType::get(C&: CurMF->getFunction().getContext(), NumBits: 1);
1288 Register Reg = DT.find(Ty: LLVMTy, MF: CurMF);
1289 if (Reg.isValid())
1290 return getSPIRVTypeForVReg(VReg: Reg);
1291 MachineBasicBlock &BB = *I.getParent();
1292 auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeBool))
1293 .addDef(createTypeVReg(CurMF->getRegInfo()));
1294 DT.add(LLVMTy, CurMF, getSPIRVTypeID(SpirvType: MIB));
1295 return finishCreatingSPIRVType(LLVMTy, SpirvType: MIB);
1296}
1297
1298SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
1299 SPIRVType *BaseType, unsigned NumElements, MachineIRBuilder &MIRBuilder) {
1300 return getOrCreateSPIRVType(
1301 FixedVectorType::get(ElementType: const_cast<Type *>(getTypeForSPIRVType(Ty: BaseType)),
1302 NumElts: NumElements),
1303 MIRBuilder);
1304}
1305
1306SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
1307 SPIRVType *BaseType, unsigned NumElements, MachineInstr &I,
1308 const SPIRVInstrInfo &TII) {
1309 Type *LLVMTy = FixedVectorType::get(
1310 ElementType: const_cast<Type *>(getTypeForSPIRVType(Ty: BaseType)), NumElts: NumElements);
1311 Register Reg = DT.find(Ty: LLVMTy, MF: CurMF);
1312 if (Reg.isValid())
1313 return getSPIRVTypeForVReg(VReg: Reg);
1314 MachineBasicBlock &BB = *I.getParent();
1315 auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeVector))
1316 .addDef(createTypeVReg(CurMF->getRegInfo()))
1317 .addUse(getSPIRVTypeID(BaseType))
1318 .addImm(NumElements);
1319 DT.add(LLVMTy, CurMF, getSPIRVTypeID(SpirvType: MIB));
1320 return finishCreatingSPIRVType(LLVMTy, SpirvType: MIB);
1321}
1322
1323SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVArrayType(
1324 SPIRVType *BaseType, unsigned NumElements, MachineInstr &I,
1325 const SPIRVInstrInfo &TII) {
1326 Type *LLVMTy = ArrayType::get(
1327 ElementType: const_cast<Type *>(getTypeForSPIRVType(Ty: BaseType)), NumElements);
1328 Register Reg = DT.find(Ty: LLVMTy, MF: CurMF);
1329 if (Reg.isValid())
1330 return getSPIRVTypeForVReg(VReg: Reg);
1331 MachineBasicBlock &BB = *I.getParent();
1332 SPIRVType *SpirvType = getOrCreateSPIRVIntegerType(BitWidth: 32, I, TII);
1333 Register Len = getOrCreateConstInt(Val: NumElements, I, SpvType: SpirvType, TII);
1334 auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeArray))
1335 .addDef(createTypeVReg(CurMF->getRegInfo()))
1336 .addUse(getSPIRVTypeID(BaseType))
1337 .addUse(Len);
1338 DT.add(LLVMTy, CurMF, getSPIRVTypeID(SpirvType: MIB));
1339 return finishCreatingSPIRVType(LLVMTy, SpirvType: MIB);
1340}
1341
1342SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
1343 SPIRVType *BaseType, MachineIRBuilder &MIRBuilder,
1344 SPIRV::StorageClass::StorageClass SC) {
1345 const Type *PointerElementType = getTypeForSPIRVType(Ty: BaseType);
1346 unsigned AddressSpace = storageClassToAddressSpace(SC);
1347 Type *LLVMTy = TypedPointerType::get(ElementType: const_cast<Type *>(PointerElementType),
1348 AddressSpace);
1349 // check if this type is already available
1350 Register Reg = DT.find(PointerElementType, AddressSpace, MF: CurMF);
1351 if (Reg.isValid())
1352 return getSPIRVTypeForVReg(VReg: Reg);
1353 // create a new type
1354 auto MIB = BuildMI(MIRBuilder.getMBB(), MIRBuilder.getInsertPt(),
1355 MIRBuilder.getDebugLoc(),
1356 MIRBuilder.getTII().get(SPIRV::OpTypePointer))
1357 .addDef(createTypeVReg(CurMF->getRegInfo()))
1358 .addImm(static_cast<uint32_t>(SC))
1359 .addUse(getSPIRVTypeID(BaseType));
1360 DT.add(PointerElementType, AddressSpace, CurMF, getSPIRVTypeID(SpirvType: MIB));
1361 return finishCreatingSPIRVType(LLVMTy, SpirvType: MIB);
1362}
1363
1364SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
1365 SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &,
1366 SPIRV::StorageClass::StorageClass SC) {
1367 MachineIRBuilder MIRBuilder(I);
1368 return getOrCreateSPIRVPointerType(BaseType, MIRBuilder, SC);
1369}
1370
1371Register SPIRVGlobalRegistry::getOrCreateUndef(MachineInstr &I,
1372 SPIRVType *SpvType,
1373 const SPIRVInstrInfo &TII) {
1374 assert(SpvType);
1375 const Type *LLVMTy = getTypeForSPIRVType(Ty: SpvType);
1376 assert(LLVMTy);
1377 // Find a constant in DT or build a new one.
1378 UndefValue *UV = UndefValue::get(T: const_cast<Type *>(LLVMTy));
1379 Register Res = DT.find(C: UV, MF: CurMF);
1380 if (Res.isValid())
1381 return Res;
1382 LLT LLTy = LLT::scalar(SizeInBits: 32);
1383 Res = CurMF->getRegInfo().createGenericVirtualRegister(Ty: LLTy);
1384 CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
1385 assignSPIRVTypeToVReg(SpirvType: SpvType, VReg: Res, MF&: *CurMF);
1386 DT.add(C: UV, MF: CurMF, R: Res);
1387
1388 MachineInstrBuilder MIB;
1389 MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpUndef))
1390 .addDef(Res)
1391 .addUse(getSPIRVTypeID(SpvType));
1392 const auto &ST = CurMF->getSubtarget();
1393 constrainSelectedInstRegOperands(I&: *MIB, TII: *ST.getInstrInfo(),
1394 TRI: *ST.getRegisterInfo(), RBI: *ST.getRegBankInfo());
1395 return Res;
1396}
1397

// Source: llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp