1//===- SPIRVISelLowering.cpp - SPIR-V DAG Lowering Impl ---------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the SPIRVTargetLowering class.
10//
11//===----------------------------------------------------------------------===//
12
13#include "SPIRVISelLowering.h"
14#include "SPIRV.h"
15#include "SPIRVInstrInfo.h"
16#include "SPIRVRegisterBankInfo.h"
17#include "SPIRVRegisterInfo.h"
18#include "SPIRVSubtarget.h"
19#include "SPIRVTargetMachine.h"
20#include "llvm/CodeGen/MachineInstrBuilder.h"
21#include "llvm/CodeGen/MachineRegisterInfo.h"
22#include "llvm/IR/IntrinsicsSPIRV.h"
23
24#define DEBUG_TYPE "spirv-lower"
25
26using namespace llvm;
27
28unsigned SPIRVTargetLowering::getNumRegistersForCallingConv(
29 LLVMContext &Context, CallingConv::ID CC, EVT VT) const {
30 // This code avoids CallLowering fail inside getVectorTypeBreakdown
31 // on v3i1 arguments. Maybe we need to return 1 for all types.
32 // TODO: remove it once this case is supported by the default implementation.
33 if (VT.isVector() && VT.getVectorNumElements() == 3 &&
34 (VT.getVectorElementType() == MVT::i1 ||
35 VT.getVectorElementType() == MVT::i8))
36 return 1;
37 if (!VT.isVector() && VT.isInteger() && VT.getSizeInBits() <= 64)
38 return 1;
39 return getNumRegisters(Context, VT);
40}
41
42MVT SPIRVTargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
43 CallingConv::ID CC,
44 EVT VT) const {
45 // This code avoids CallLowering fail inside getVectorTypeBreakdown
46 // on v3i1 arguments. Maybe we need to return i32 for all types.
47 // TODO: remove it once this case is supported by the default implementation.
48 if (VT.isVector() && VT.getVectorNumElements() == 3) {
49 if (VT.getVectorElementType() == MVT::i1)
50 return MVT::v4i1;
51 else if (VT.getVectorElementType() == MVT::i8)
52 return MVT::v4i8;
53 }
54 return getRegisterType(Context, VT);
55}
56
57bool SPIRVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
58 const CallInst &I,
59 MachineFunction &MF,
60 unsigned Intrinsic) const {
61 unsigned AlignIdx = 3;
62 switch (Intrinsic) {
63 case Intrinsic::spv_load:
64 AlignIdx = 2;
65 [[fallthrough]];
66 case Intrinsic::spv_store: {
67 if (I.getNumOperands() >= AlignIdx + 1) {
68 auto *AlignOp = cast<ConstantInt>(Val: I.getOperand(i_nocapture: AlignIdx));
69 Info.align = Align(AlignOp->getZExtValue());
70 }
71 Info.flags = static_cast<MachineMemOperand::Flags>(
72 cast<ConstantInt>(Val: I.getOperand(i_nocapture: AlignIdx - 1))->getZExtValue());
73 Info.memVT = MVT::i64;
74 // TODO: take into account opaque pointers (don't use getElementType).
75 // MVT::getVT(PtrTy->getElementType());
76 return true;
77 break;
78 }
79 default:
80 break;
81 }
82 return false;
83}
84
85// Insert a bitcast before the instruction to keep SPIR-V code valid
86// when there is a type mismatch between results and operand types.
87static void validatePtrTypes(const SPIRVSubtarget &STI,
88 MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR,
89 MachineInstr &I, unsigned OpIdx,
90 SPIRVType *ResType, const Type *ResTy = nullptr) {
91 // Get operand type
92 MachineFunction *MF = I.getParent()->getParent();
93 Register OpReg = I.getOperand(i: OpIdx).getReg();
94 SPIRVType *TypeInst = MRI->getVRegDef(Reg: OpReg);
95 Register OpTypeReg =
96 TypeInst && TypeInst->getOpcode() == SPIRV::OpFunctionParameter
97 ? TypeInst->getOperand(i: 1).getReg()
98 : OpReg;
99 SPIRVType *OpType = GR.getSPIRVTypeForVReg(VReg: OpTypeReg, MF);
100 if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
101 return;
102 // Get operand's pointee type
103 Register ElemTypeReg = OpType->getOperand(i: 2).getReg();
104 SPIRVType *ElemType = GR.getSPIRVTypeForVReg(VReg: ElemTypeReg, MF);
105 if (!ElemType)
106 return;
107 // Check if we need a bitcast to make a statement valid
108 bool IsSameMF = MF == ResType->getParent()->getParent();
109 bool IsEqualTypes = IsSameMF ? ElemType == ResType
110 : GR.getTypeForSPIRVType(Ty: ElemType) == ResTy;
111 if (IsEqualTypes)
112 return;
113 // There is a type mismatch between results and operand types
114 // and we insert a bitcast before the instruction to keep SPIR-V code valid
115 SPIRV::StorageClass::StorageClass SC =
116 static_cast<SPIRV::StorageClass::StorageClass>(
117 OpType->getOperand(i: 1).getImm());
118 MachineIRBuilder MIB(I);
119 SPIRVType *NewBaseType =
120 IsSameMF ? ResType
121 : GR.getOrCreateSPIRVType(
122 BitWidth: ResTy, I&: MIB, SPIRV::AccessQualifier::TII: ReadWrite, SPIRVOPcode: false);
123 SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(BaseType: NewBaseType, MIRBuilder&: MIB, SClass: SC);
124 if (!GR.isBitcastCompatible(Type1: NewPtrType, Type2: OpType))
125 report_fatal_error(
126 reason: "insert validation bitcast: incompatible result and operand types");
127 Register NewReg = MRI->createGenericVirtualRegister(Ty: LLT::scalar(SizeInBits: 32));
128 bool Res = MIB.buildInstr(SPIRV::OpBitcast)
129 .addDef(NewReg)
130 .addUse(GR.getSPIRVTypeID(SpirvType: NewPtrType))
131 .addUse(OpReg)
132 .constrainAllUses(*STI.getInstrInfo(), *STI.getRegisterInfo(),
133 *STI.getRegBankInfo());
134 if (!Res)
135 report_fatal_error(reason: "insert validation bitcast: cannot constrain all uses");
136 MRI->setRegClass(NewReg, &SPIRV::IDRegClass);
137 GR.assignSPIRVTypeToVReg(Type: NewPtrType, VReg: NewReg, MF&: MIB.getMF());
138 I.getOperand(i: OpIdx).setReg(NewReg);
139}
140
// Insert a bitcast before the function call instruction to keep SPIR-V code
// valid when there is a type mismatch between actual and expected types of an
// argument:
//   %formal = OpFunctionParameter %formal_type
//   ...
//   %res = OpFunctionCall %ty %fun %actual ...
// implies that %actual is of %formal_type, and in the case of opaque pointers
// we may need to insert a bitcast to ensure this.
void validateFunCallMachineDef(const SPIRVSubtarget &STI,
                               MachineRegisterInfo *DefMRI,
                               MachineRegisterInfo *CallMRI,
                               SPIRVGlobalRegistry &GR, MachineInstr &FunCall,
                               MachineInstr *FunDef) {
  if (FunDef->getOpcode() != SPIRV::OpFunction)
    return;
  // OpFunctionCall actual arguments start at operand index 3; they line up
  // with the OpFunctionParameter instructions that follow OpFunction.
  unsigned OpIdx = 3;
  for (FunDef = FunDef->getNextNode();
       FunDef && FunDef->getOpcode() == SPIRV::OpFunctionParameter &&
       OpIdx < FunCall.getNumOperands();
       FunDef = FunDef->getNextNode(), OpIdx++) {
    SPIRVType *DefPtrType = DefMRI->getVRegDef(FunDef->getOperand(1).getReg());
    // Only pointer-typed formal parameters can require a validation bitcast.
    SPIRVType *DefElemType =
        DefPtrType && DefPtrType->getOpcode() == SPIRV::OpTypePointer
            ? GR.getSPIRVTypeForVReg(DefPtrType->getOperand(2).getReg(),
                                     DefPtrType->getParent()->getParent())
            : nullptr;
    if (DefElemType) {
      const Type *DefElemTy = GR.getTypeForSPIRVType(DefElemType);
      // validatePtrTypes() works in the context of the call site.
      // When we process historical records about forward calls
      // we need to switch context to the (forward) call site and
      // then restore it back to the current machine function.
      MachineFunction *CurMF =
          GR.setCurrentFunc(*FunCall.getParent()->getParent());
      validatePtrTypes(STI, CallMRI, GR, FunCall, OpIdx, DefElemType,
                       DefElemTy);
      GR.setCurrentFunc(*CurMF);
    }
  }
}
181
182// Ensure there is no mismatch between actual and expected arg types: calls
183// with a processed definition. Return Function pointer if it's a forward
184// call (ahead of definition), and nullptr otherwise.
185const Function *validateFunCall(const SPIRVSubtarget &STI,
186 MachineRegisterInfo *CallMRI,
187 SPIRVGlobalRegistry &GR,
188 MachineInstr &FunCall) {
189 const GlobalValue *GV = FunCall.getOperand(i: 2).getGlobal();
190 const Function *F = dyn_cast<Function>(Val: GV);
191 MachineInstr *FunDef =
192 const_cast<MachineInstr *>(GR.getFunctionDefinition(F));
193 if (!FunDef)
194 return F;
195 MachineRegisterInfo *DefMRI = &FunDef->getParent()->getParent()->getRegInfo();
196 validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, FunCall, FunDef);
197 return nullptr;
198}
199
200// Ensure there is no mismatch between actual and expected arg types: calls
201// ahead of a processed definition.
202void validateForwardCalls(const SPIRVSubtarget &STI,
203 MachineRegisterInfo *DefMRI, SPIRVGlobalRegistry &GR,
204 MachineInstr &FunDef) {
205 const Function *F = GR.getFunctionByDefinition(MI: &FunDef);
206 if (SmallPtrSet<MachineInstr *, 8> *FwdCalls = GR.getForwardCalls(F))
207 for (MachineInstr *FunCall : *FwdCalls) {
208 MachineRegisterInfo *CallMRI =
209 &FunCall->getParent()->getParent()->getRegInfo();
210 validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, FunCall&: *FunCall, FunDef: &FunDef);
211 }
212}
213
214// Validation of an access chain.
215void validateAccessChain(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
216 SPIRVGlobalRegistry &GR, MachineInstr &I) {
217 SPIRVType *BaseTypeInst = GR.getSPIRVTypeForVReg(VReg: I.getOperand(i: 0).getReg());
218 if (BaseTypeInst && BaseTypeInst->getOpcode() == SPIRV::OpTypePointer) {
219 SPIRVType *BaseElemType =
220 GR.getSPIRVTypeForVReg(VReg: BaseTypeInst->getOperand(i: 2).getReg());
221 validatePtrTypes(STI, MRI, GR, I, OpIdx: 2, ResType: BaseElemType);
222 }
223}
224
// TODO: the logic of inserting additional bitcast's is to be moved
// to pre-IRTranslation passes eventually
// Post-lowering validation pass: walks every instruction of MF, inserting
// bitcasts where pointer operand types disagree with result/stored types,
// and rewriting bitwise ops on booleans into their logical equivalents.
void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
  // finalizeLowering() is called twice (see GlobalISel/InstructionSelect.cpp)
  // We'd like to avoid the needless second processing pass.
  if (ProcessedMF.find(x: &MF) != ProcessedMF.end())
    return;

  MachineRegisterInfo *MRI = &MF.getRegInfo();
  SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();
  GR.setCurrentFunc(MF);
  for (MachineFunction::iterator I = MF.begin(), E = MF.end(); I != E; ++I) {
    MachineBasicBlock *MBB = &*I;
    // Advance MBBI before processing MI: the validation helpers may insert
    // new instructions immediately before MI.
    for (MachineBasicBlock::iterator MBBI = MBB->begin(), MBBE = MBB->end();
         MBBI != MBBE;) {
      MachineInstr &MI = *MBBI++;
      switch (MI.getOpcode()) {
      case SPIRV::OpAtomicLoad:
      case SPIRV::OpAtomicExchange:
      case SPIRV::OpAtomicCompareExchange:
      case SPIRV::OpAtomicCompareExchangeWeak:
      case SPIRV::OpAtomicIIncrement:
      case SPIRV::OpAtomicIDecrement:
      case SPIRV::OpAtomicIAdd:
      case SPIRV::OpAtomicISub:
      case SPIRV::OpAtomicSMin:
      case SPIRV::OpAtomicUMin:
      case SPIRV::OpAtomicSMax:
      case SPIRV::OpAtomicUMax:
      case SPIRV::OpAtomicAnd:
      case SPIRV::OpAtomicOr:
      case SPIRV::OpAtomicXor:
        // for the above listed instructions
        // OpAtomicXXX <ResType>, ptr %Op, ...
        // implies that %Op is a pointer to <ResType>
      case SPIRV::OpLoad:
        // OpLoad <ResType>, ptr %Op implies that %Op is a pointer to <ResType>
        validatePtrTypes(STI, MRI, GR, I&: MI, OpIdx: 2,
                         ResType: GR.getSPIRVTypeForVReg(VReg: MI.getOperand(i: 0).getReg()));
        break;
      case SPIRV::OpAtomicStore:
        // OpAtomicStore ptr %Op, <Scope>, <Mem>, <Obj>
        // implies that %Op points to the <Obj>'s type
        validatePtrTypes(STI, MRI, GR, I&: MI, OpIdx: 0,
                         ResType: GR.getSPIRVTypeForVReg(VReg: MI.getOperand(i: 3).getReg()));
        break;
      case SPIRV::OpStore:
        // OpStore ptr %Op, <Obj> implies that %Op points to the <Obj>'s type
        validatePtrTypes(STI, MRI, GR, I&: MI, OpIdx: 0,
                         ResType: GR.getSPIRVTypeForVReg(VReg: MI.getOperand(i: 1).getReg()));
        break;
      case SPIRV::OpPtrCastToGeneric:
        validateAccessChain(STI, MRI, GR, I&: MI);
        break;
      case SPIRV::OpInBoundsPtrAccessChain:
        // Only validate the no-indices form (result, result type, base, elem).
        if (MI.getNumOperands() == 4)
          validateAccessChain(STI, MRI, GR, I&: MI);
        break;

      case SPIRV::OpFunctionCall:
        // ensure there is no mismatch between actual and expected arg types:
        // calls with a processed definition; calls ahead of a definition are
        // recorded for later validation in the OpFunction case below
        if (MI.getNumOperands() > 3)
          if (const Function *F = validateFunCall(STI, CallMRI: MRI, GR, FunCall&: MI))
            GR.addForwardCall(F, MI: &MI);
        break;
      case SPIRV::OpFunction:
        // ensure there is no mismatch between actual and expected arg types:
        // calls ahead of a processed definition
        validateForwardCalls(STI, DefMRI: MRI, GR, FunDef&: MI);
        break;

      // ensure that LLVM IR bitwise instructions result in logical SPIR-V
      // instructions when applied to bool type
      case SPIRV::OpBitwiseOrS:
      case SPIRV::OpBitwiseOrV:
        if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
                                      SPIRV::OpTypeBool))
          MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalOr));
        break;
      case SPIRV::OpBitwiseAndS:
      case SPIRV::OpBitwiseAndV:
        if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
                                      SPIRV::OpTypeBool))
          MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalAnd));
        break;
      case SPIRV::OpBitwiseXorS:
      case SPIRV::OpBitwiseXorV:
        // xor on booleans is inequality
        if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
                                      SPIRV::OpTypeBool))
          MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual));
        break;
      }
    }
  }
  ProcessedMF.insert(x: &MF);
  TargetLowering::finalizeLowering(MF);
}
323

source code of llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp