1 | //===- SPIRVISelLowering.cpp - SPIR-V DAG Lowering Impl ---------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements the SPIRVTargetLowering class. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "SPIRVISelLowering.h" |
14 | #include "SPIRV.h" |
15 | #include "SPIRVInstrInfo.h" |
16 | #include "SPIRVRegisterBankInfo.h" |
17 | #include "SPIRVRegisterInfo.h" |
18 | #include "SPIRVSubtarget.h" |
19 | #include "SPIRVTargetMachine.h" |
20 | #include "llvm/CodeGen/MachineInstrBuilder.h" |
21 | #include "llvm/CodeGen/MachineRegisterInfo.h" |
22 | #include "llvm/IR/IntrinsicsSPIRV.h" |
23 | |
24 | #define DEBUG_TYPE "spirv-lower" |
25 | |
26 | using namespace llvm; |
27 | |
28 | unsigned SPIRVTargetLowering::getNumRegistersForCallingConv( |
29 | LLVMContext &Context, CallingConv::ID CC, EVT VT) const { |
30 | // This code avoids CallLowering fail inside getVectorTypeBreakdown |
31 | // on v3i1 arguments. Maybe we need to return 1 for all types. |
32 | // TODO: remove it once this case is supported by the default implementation. |
33 | if (VT.isVector() && VT.getVectorNumElements() == 3 && |
34 | (VT.getVectorElementType() == MVT::i1 || |
35 | VT.getVectorElementType() == MVT::i8)) |
36 | return 1; |
37 | if (!VT.isVector() && VT.isInteger() && VT.getSizeInBits() <= 64) |
38 | return 1; |
39 | return getNumRegisters(Context, VT); |
40 | } |
41 | |
42 | MVT SPIRVTargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context, |
43 | CallingConv::ID CC, |
44 | EVT VT) const { |
45 | // This code avoids CallLowering fail inside getVectorTypeBreakdown |
46 | // on v3i1 arguments. Maybe we need to return i32 for all types. |
47 | // TODO: remove it once this case is supported by the default implementation. |
48 | if (VT.isVector() && VT.getVectorNumElements() == 3) { |
49 | if (VT.getVectorElementType() == MVT::i1) |
50 | return MVT::v4i1; |
51 | else if (VT.getVectorElementType() == MVT::i8) |
52 | return MVT::v4i8; |
53 | } |
54 | return getRegisterType(Context, VT); |
55 | } |
56 | |
57 | bool SPIRVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info, |
58 | const CallInst &I, |
59 | MachineFunction &MF, |
60 | unsigned Intrinsic) const { |
61 | unsigned AlignIdx = 3; |
62 | switch (Intrinsic) { |
63 | case Intrinsic::spv_load: |
64 | AlignIdx = 2; |
65 | [[fallthrough]]; |
66 | case Intrinsic::spv_store: { |
67 | if (I.getNumOperands() >= AlignIdx + 1) { |
68 | auto *AlignOp = cast<ConstantInt>(Val: I.getOperand(i_nocapture: AlignIdx)); |
69 | Info.align = Align(AlignOp->getZExtValue()); |
70 | } |
71 | Info.flags = static_cast<MachineMemOperand::Flags>( |
72 | cast<ConstantInt>(Val: I.getOperand(i_nocapture: AlignIdx - 1))->getZExtValue()); |
73 | Info.memVT = MVT::i64; |
74 | // TODO: take into account opaque pointers (don't use getElementType). |
75 | // MVT::getVT(PtrTy->getElementType()); |
76 | return true; |
77 | break; |
78 | } |
79 | default: |
80 | break; |
81 | } |
82 | return false; |
83 | } |
84 | |
85 | // Insert a bitcast before the instruction to keep SPIR-V code valid |
86 | // when there is a type mismatch between results and operand types. |
87 | static void validatePtrTypes(const SPIRVSubtarget &STI, |
88 | MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, |
89 | MachineInstr &I, unsigned OpIdx, |
90 | SPIRVType *ResType, const Type *ResTy = nullptr) { |
91 | // Get operand type |
92 | MachineFunction *MF = I.getParent()->getParent(); |
93 | Register OpReg = I.getOperand(i: OpIdx).getReg(); |
94 | SPIRVType *TypeInst = MRI->getVRegDef(Reg: OpReg); |
95 | Register OpTypeReg = |
96 | TypeInst && TypeInst->getOpcode() == SPIRV::OpFunctionParameter |
97 | ? TypeInst->getOperand(i: 1).getReg() |
98 | : OpReg; |
99 | SPIRVType *OpType = GR.getSPIRVTypeForVReg(VReg: OpTypeReg, MF); |
100 | if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer) |
101 | return; |
102 | // Get operand's pointee type |
103 | Register ElemTypeReg = OpType->getOperand(i: 2).getReg(); |
104 | SPIRVType *ElemType = GR.getSPIRVTypeForVReg(VReg: ElemTypeReg, MF); |
105 | if (!ElemType) |
106 | return; |
107 | // Check if we need a bitcast to make a statement valid |
108 | bool IsSameMF = MF == ResType->getParent()->getParent(); |
109 | bool IsEqualTypes = IsSameMF ? ElemType == ResType |
110 | : GR.getTypeForSPIRVType(Ty: ElemType) == ResTy; |
111 | if (IsEqualTypes) |
112 | return; |
113 | // There is a type mismatch between results and operand types |
114 | // and we insert a bitcast before the instruction to keep SPIR-V code valid |
115 | SPIRV::StorageClass::StorageClass SC = |
116 | static_cast<SPIRV::StorageClass::StorageClass>( |
117 | OpType->getOperand(i: 1).getImm()); |
118 | MachineIRBuilder MIB(I); |
119 | SPIRVType *NewBaseType = |
120 | IsSameMF ? ResType |
121 | : GR.getOrCreateSPIRVType( |
122 | BitWidth: ResTy, I&: MIB, SPIRV::AccessQualifier::TII: ReadWrite, SPIRVOPcode: false); |
123 | SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(BaseType: NewBaseType, MIRBuilder&: MIB, SClass: SC); |
124 | if (!GR.isBitcastCompatible(Type1: NewPtrType, Type2: OpType)) |
125 | report_fatal_error( |
126 | reason: "insert validation bitcast: incompatible result and operand types" ); |
127 | Register NewReg = MRI->createGenericVirtualRegister(Ty: LLT::scalar(SizeInBits: 32)); |
128 | bool Res = MIB.buildInstr(SPIRV::OpBitcast) |
129 | .addDef(NewReg) |
130 | .addUse(GR.getSPIRVTypeID(SpirvType: NewPtrType)) |
131 | .addUse(OpReg) |
132 | .constrainAllUses(*STI.getInstrInfo(), *STI.getRegisterInfo(), |
133 | *STI.getRegBankInfo()); |
134 | if (!Res) |
135 | report_fatal_error(reason: "insert validation bitcast: cannot constrain all uses" ); |
136 | MRI->setRegClass(NewReg, &SPIRV::IDRegClass); |
137 | GR.assignSPIRVTypeToVReg(Type: NewPtrType, VReg: NewReg, MF&: MIB.getMF()); |
138 | I.getOperand(i: OpIdx).setReg(NewReg); |
139 | } |
140 | |
141 | // Insert a bitcast before the function call instruction to keep SPIR-V code |
142 | // valid when there is a type mismatch between actual and expected types of an |
143 | // argument: |
144 | // %formal = OpFunctionParameter %formal_type |
145 | // ... |
146 | // %res = OpFunctionCall %ty %fun %actual ... |
147 | // implies that %actual is of %formal_type, and in case of opaque pointers. |
148 | // We may need to insert a bitcast to ensure this. |
149 | void validateFunCallMachineDef(const SPIRVSubtarget &STI, |
150 | MachineRegisterInfo *DefMRI, |
151 | MachineRegisterInfo *CallMRI, |
152 | SPIRVGlobalRegistry &GR, MachineInstr &FunCall, |
153 | MachineInstr *FunDef) { |
154 | if (FunDef->getOpcode() != SPIRV::OpFunction) |
155 | return; |
156 | unsigned OpIdx = 3; |
157 | for (FunDef = FunDef->getNextNode(); |
158 | FunDef && FunDef->getOpcode() == SPIRV::OpFunctionParameter && |
159 | OpIdx < FunCall.getNumOperands(); |
160 | FunDef = FunDef->getNextNode(), OpIdx++) { |
161 | SPIRVType *DefPtrType = DefMRI->getVRegDef(FunDef->getOperand(1).getReg()); |
162 | SPIRVType *DefElemType = |
163 | DefPtrType && DefPtrType->getOpcode() == SPIRV::OpTypePointer |
164 | ? GR.getSPIRVTypeForVReg(DefPtrType->getOperand(2).getReg(), |
165 | DefPtrType->getParent()->getParent()) |
166 | : nullptr; |
167 | if (DefElemType) { |
168 | const Type *DefElemTy = GR.getTypeForSPIRVType(DefElemType); |
169 | // validatePtrTypes() works in the context if the call site |
170 | // When we process historical records about forward calls |
171 | // we need to switch context to the (forward) call site and |
172 | // then restore it back to the current machine function. |
173 | MachineFunction *CurMF = |
174 | GR.setCurrentFunc(*FunCall.getParent()->getParent()); |
175 | validatePtrTypes(STI, CallMRI, GR, FunCall, OpIdx, DefElemType, |
176 | DefElemTy); |
177 | GR.setCurrentFunc(*CurMF); |
178 | } |
179 | } |
180 | } |
181 | |
182 | // Ensure there is no mismatch between actual and expected arg types: calls |
183 | // with a processed definition. Return Function pointer if it's a forward |
184 | // call (ahead of definition), and nullptr otherwise. |
185 | const Function *validateFunCall(const SPIRVSubtarget &STI, |
186 | MachineRegisterInfo *CallMRI, |
187 | SPIRVGlobalRegistry &GR, |
188 | MachineInstr &FunCall) { |
189 | const GlobalValue *GV = FunCall.getOperand(i: 2).getGlobal(); |
190 | const Function *F = dyn_cast<Function>(Val: GV); |
191 | MachineInstr *FunDef = |
192 | const_cast<MachineInstr *>(GR.getFunctionDefinition(F)); |
193 | if (!FunDef) |
194 | return F; |
195 | MachineRegisterInfo *DefMRI = &FunDef->getParent()->getParent()->getRegInfo(); |
196 | validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, FunCall, FunDef); |
197 | return nullptr; |
198 | } |
199 | |
200 | // Ensure there is no mismatch between actual and expected arg types: calls |
201 | // ahead of a processed definition. |
202 | void validateForwardCalls(const SPIRVSubtarget &STI, |
203 | MachineRegisterInfo *DefMRI, SPIRVGlobalRegistry &GR, |
204 | MachineInstr &FunDef) { |
205 | const Function *F = GR.getFunctionByDefinition(MI: &FunDef); |
206 | if (SmallPtrSet<MachineInstr *, 8> *FwdCalls = GR.getForwardCalls(F)) |
207 | for (MachineInstr *FunCall : *FwdCalls) { |
208 | MachineRegisterInfo *CallMRI = |
209 | &FunCall->getParent()->getParent()->getRegInfo(); |
210 | validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, FunCall&: *FunCall, FunDef: &FunDef); |
211 | } |
212 | } |
213 | |
214 | // Validation of an access chain. |
215 | void validateAccessChain(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, |
216 | SPIRVGlobalRegistry &GR, MachineInstr &I) { |
217 | SPIRVType *BaseTypeInst = GR.getSPIRVTypeForVReg(VReg: I.getOperand(i: 0).getReg()); |
218 | if (BaseTypeInst && BaseTypeInst->getOpcode() == SPIRV::OpTypePointer) { |
219 | SPIRVType *BaseElemType = |
220 | GR.getSPIRVTypeForVReg(VReg: BaseTypeInst->getOperand(i: 2).getReg()); |
221 | validatePtrTypes(STI, MRI, GR, I, OpIdx: 2, ResType: BaseElemType); |
222 | } |
223 | } |
224 | |
225 | // TODO: the logic of inserting additional bitcast's is to be moved |
226 | // to pre-IRTranslation passes eventually |
227 | void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const { |
228 | // finalizeLowering() is called twice (see GlobalISel/InstructionSelect.cpp) |
229 | // We'd like to avoid the needless second processing pass. |
230 | if (ProcessedMF.find(x: &MF) != ProcessedMF.end()) |
231 | return; |
232 | |
233 | MachineRegisterInfo *MRI = &MF.getRegInfo(); |
234 | SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry(); |
235 | GR.setCurrentFunc(MF); |
236 | for (MachineFunction::iterator I = MF.begin(), E = MF.end(); I != E; ++I) { |
237 | MachineBasicBlock *MBB = &*I; |
238 | for (MachineBasicBlock::iterator MBBI = MBB->begin(), MBBE = MBB->end(); |
239 | MBBI != MBBE;) { |
240 | MachineInstr &MI = *MBBI++; |
241 | switch (MI.getOpcode()) { |
242 | case SPIRV::OpAtomicLoad: |
243 | case SPIRV::OpAtomicExchange: |
244 | case SPIRV::OpAtomicCompareExchange: |
245 | case SPIRV::OpAtomicCompareExchangeWeak: |
246 | case SPIRV::OpAtomicIIncrement: |
247 | case SPIRV::OpAtomicIDecrement: |
248 | case SPIRV::OpAtomicIAdd: |
249 | case SPIRV::OpAtomicISub: |
250 | case SPIRV::OpAtomicSMin: |
251 | case SPIRV::OpAtomicUMin: |
252 | case SPIRV::OpAtomicSMax: |
253 | case SPIRV::OpAtomicUMax: |
254 | case SPIRV::OpAtomicAnd: |
255 | case SPIRV::OpAtomicOr: |
256 | case SPIRV::OpAtomicXor: |
257 | // for the above listed instructions |
258 | // OpAtomicXXX <ResType>, ptr %Op, ... |
259 | // implies that %Op is a pointer to <ResType> |
260 | case SPIRV::OpLoad: |
261 | // OpLoad <ResType>, ptr %Op implies that %Op is a pointer to <ResType> |
262 | validatePtrTypes(STI, MRI, GR, I&: MI, OpIdx: 2, |
263 | ResType: GR.getSPIRVTypeForVReg(VReg: MI.getOperand(i: 0).getReg())); |
264 | break; |
265 | case SPIRV::OpAtomicStore: |
266 | // OpAtomicStore ptr %Op, <Scope>, <Mem>, <Obj> |
267 | // implies that %Op points to the <Obj>'s type |
268 | validatePtrTypes(STI, MRI, GR, I&: MI, OpIdx: 0, |
269 | ResType: GR.getSPIRVTypeForVReg(VReg: MI.getOperand(i: 3).getReg())); |
270 | break; |
271 | case SPIRV::OpStore: |
272 | // OpStore ptr %Op, <Obj> implies that %Op points to the <Obj>'s type |
273 | validatePtrTypes(STI, MRI, GR, I&: MI, OpIdx: 0, |
274 | ResType: GR.getSPIRVTypeForVReg(VReg: MI.getOperand(i: 1).getReg())); |
275 | break; |
276 | case SPIRV::OpPtrCastToGeneric: |
277 | validateAccessChain(STI, MRI, GR, I&: MI); |
278 | break; |
279 | case SPIRV::OpInBoundsPtrAccessChain: |
280 | if (MI.getNumOperands() == 4) |
281 | validateAccessChain(STI, MRI, GR, I&: MI); |
282 | break; |
283 | |
284 | case SPIRV::OpFunctionCall: |
285 | // ensure there is no mismatch between actual and expected arg types: |
286 | // calls with a processed definition |
287 | if (MI.getNumOperands() > 3) |
288 | if (const Function *F = validateFunCall(STI, CallMRI: MRI, GR, FunCall&: MI)) |
289 | GR.addForwardCall(F, MI: &MI); |
290 | break; |
291 | case SPIRV::OpFunction: |
292 | // ensure there is no mismatch between actual and expected arg types: |
293 | // calls ahead of a processed definition |
294 | validateForwardCalls(STI, DefMRI: MRI, GR, FunDef&: MI); |
295 | break; |
296 | |
297 | // ensure that LLVM IR bitwise instructions result in logical SPIR-V |
298 | // instructions when applied to bool type |
299 | case SPIRV::OpBitwiseOrS: |
300 | case SPIRV::OpBitwiseOrV: |
301 | if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(), |
302 | SPIRV::OpTypeBool)) |
303 | MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalOr)); |
304 | break; |
305 | case SPIRV::OpBitwiseAndS: |
306 | case SPIRV::OpBitwiseAndV: |
307 | if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(), |
308 | SPIRV::OpTypeBool)) |
309 | MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalAnd)); |
310 | break; |
311 | case SPIRV::OpBitwiseXorS: |
312 | case SPIRV::OpBitwiseXorV: |
313 | if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(), |
314 | SPIRV::OpTypeBool)) |
315 | MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual)); |
316 | break; |
317 | } |
318 | } |
319 | } |
320 | ProcessedMF.insert(x: &MF); |
321 | TargetLowering::finalizeLowering(MF); |
322 | } |
323 | |