1 | //===- SPIRVLegalizerInfo.cpp --- SPIR-V Legalization Rules ------*- C++ -*-==// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This file implements the targeting of the MachineLegalizer class for SPIR-V.
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "SPIRVLegalizerInfo.h" |
14 | #include "SPIRV.h" |
15 | #include "SPIRVGlobalRegistry.h" |
16 | #include "SPIRVSubtarget.h" |
17 | #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h" |
18 | #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" |
19 | #include "llvm/CodeGen/MachineInstr.h" |
20 | #include "llvm/CodeGen/MachineRegisterInfo.h" |
21 | #include "llvm/CodeGen/TargetOpcodes.h" |
22 | |
23 | using namespace llvm; |
24 | using namespace llvm::LegalizeActions; |
25 | using namespace llvm::LegalityPredicates; |
26 | |
27 | static const std::set<unsigned> TypeFoldingSupportingOpcs = { |
28 | TargetOpcode::G_ADD, |
29 | TargetOpcode::G_FADD, |
30 | TargetOpcode::G_SUB, |
31 | TargetOpcode::G_FSUB, |
32 | TargetOpcode::G_MUL, |
33 | TargetOpcode::G_FMUL, |
34 | TargetOpcode::G_SDIV, |
35 | TargetOpcode::G_UDIV, |
36 | TargetOpcode::G_FDIV, |
37 | TargetOpcode::G_SREM, |
38 | TargetOpcode::G_UREM, |
39 | TargetOpcode::G_FREM, |
40 | TargetOpcode::G_FNEG, |
41 | TargetOpcode::G_CONSTANT, |
42 | TargetOpcode::G_FCONSTANT, |
43 | TargetOpcode::G_AND, |
44 | TargetOpcode::G_OR, |
45 | TargetOpcode::G_XOR, |
46 | TargetOpcode::G_SHL, |
47 | TargetOpcode::G_ASHR, |
48 | TargetOpcode::G_LSHR, |
49 | TargetOpcode::G_SELECT, |
50 | TargetOpcode::G_EXTRACT_VECTOR_ELT, |
51 | }; |
52 | |
53 | bool isTypeFoldingSupported(unsigned Opcode) { |
54 | return TypeFoldingSupportingOpcs.count(x: Opcode) > 0; |
55 | } |
56 | |
57 | SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { |
58 | using namespace TargetOpcode; |
59 | |
60 | this->ST = &ST; |
61 | GR = ST.getSPIRVGlobalRegistry(); |
62 | |
63 | const LLT s1 = LLT::scalar(SizeInBits: 1); |
64 | const LLT s8 = LLT::scalar(SizeInBits: 8); |
65 | const LLT s16 = LLT::scalar(SizeInBits: 16); |
66 | const LLT s32 = LLT::scalar(SizeInBits: 32); |
67 | const LLT s64 = LLT::scalar(SizeInBits: 64); |
68 | |
69 | const LLT v16s64 = LLT::fixed_vector(NumElements: 16, ScalarSizeInBits: 64); |
70 | const LLT v16s32 = LLT::fixed_vector(NumElements: 16, ScalarSizeInBits: 32); |
71 | const LLT v16s16 = LLT::fixed_vector(NumElements: 16, ScalarSizeInBits: 16); |
72 | const LLT v16s8 = LLT::fixed_vector(NumElements: 16, ScalarSizeInBits: 8); |
73 | const LLT v16s1 = LLT::fixed_vector(NumElements: 16, ScalarSizeInBits: 1); |
74 | |
75 | const LLT v8s64 = LLT::fixed_vector(NumElements: 8, ScalarSizeInBits: 64); |
76 | const LLT v8s32 = LLT::fixed_vector(NumElements: 8, ScalarSizeInBits: 32); |
77 | const LLT v8s16 = LLT::fixed_vector(NumElements: 8, ScalarSizeInBits: 16); |
78 | const LLT v8s8 = LLT::fixed_vector(NumElements: 8, ScalarSizeInBits: 8); |
79 | const LLT v8s1 = LLT::fixed_vector(NumElements: 8, ScalarSizeInBits: 1); |
80 | |
81 | const LLT v4s64 = LLT::fixed_vector(NumElements: 4, ScalarSizeInBits: 64); |
82 | const LLT v4s32 = LLT::fixed_vector(NumElements: 4, ScalarSizeInBits: 32); |
83 | const LLT v4s16 = LLT::fixed_vector(NumElements: 4, ScalarSizeInBits: 16); |
84 | const LLT v4s8 = LLT::fixed_vector(NumElements: 4, ScalarSizeInBits: 8); |
85 | const LLT v4s1 = LLT::fixed_vector(NumElements: 4, ScalarSizeInBits: 1); |
86 | |
87 | const LLT v3s64 = LLT::fixed_vector(NumElements: 3, ScalarSizeInBits: 64); |
88 | const LLT v3s32 = LLT::fixed_vector(NumElements: 3, ScalarSizeInBits: 32); |
89 | const LLT v3s16 = LLT::fixed_vector(NumElements: 3, ScalarSizeInBits: 16); |
90 | const LLT v3s8 = LLT::fixed_vector(NumElements: 3, ScalarSizeInBits: 8); |
91 | const LLT v3s1 = LLT::fixed_vector(NumElements: 3, ScalarSizeInBits: 1); |
92 | |
93 | const LLT v2s64 = LLT::fixed_vector(NumElements: 2, ScalarSizeInBits: 64); |
94 | const LLT v2s32 = LLT::fixed_vector(NumElements: 2, ScalarSizeInBits: 32); |
95 | const LLT v2s16 = LLT::fixed_vector(NumElements: 2, ScalarSizeInBits: 16); |
96 | const LLT v2s8 = LLT::fixed_vector(NumElements: 2, ScalarSizeInBits: 8); |
97 | const LLT v2s1 = LLT::fixed_vector(NumElements: 2, ScalarSizeInBits: 1); |
98 | |
99 | const unsigned PSize = ST.getPointerSize(); |
100 | const LLT p0 = LLT::pointer(AddressSpace: 0, SizeInBits: PSize); // Function |
101 | const LLT p1 = LLT::pointer(AddressSpace: 1, SizeInBits: PSize); // CrossWorkgroup |
102 | const LLT p2 = LLT::pointer(AddressSpace: 2, SizeInBits: PSize); // UniformConstant |
103 | const LLT p3 = LLT::pointer(AddressSpace: 3, SizeInBits: PSize); // Workgroup |
104 | const LLT p4 = LLT::pointer(AddressSpace: 4, SizeInBits: PSize); // Generic |
105 | const LLT p5 = |
106 | LLT::pointer(AddressSpace: 5, SizeInBits: PSize); // Input, SPV_INTEL_usm_storage_classes (Device) |
107 | const LLT p6 = LLT::pointer(AddressSpace: 6, SizeInBits: PSize); // SPV_INTEL_usm_storage_classes (Host) |
108 | |
109 | // TODO: remove copy-pasting here by using concatenation in some way. |
110 | auto allPtrsScalarsAndVectors = { |
111 | p0, p1, p2, p3, p4, p5, p6, s1, s8, s16, |
112 | s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, v3s16, |
113 | v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, v8s1, v8s8, v8s16, |
114 | v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64}; |
115 | |
116 | auto allVectors = {v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, |
117 | v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, |
118 | v4s64, v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, |
119 | v16s8, v16s16, v16s32, v16s64}; |
120 | |
121 | auto allScalarsAndVectors = { |
122 | s1, s8, s16, s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64, |
123 | v3s1, v3s8, v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, |
124 | v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64}; |
125 | |
126 | auto allIntScalarsAndVectors = {s8, s16, s32, s64, v2s8, v2s16, |
127 | v2s32, v2s64, v3s8, v3s16, v3s32, v3s64, |
128 | v4s8, v4s16, v4s32, v4s64, v8s8, v8s16, |
129 | v8s32, v8s64, v16s8, v16s16, v16s32, v16s64}; |
130 | |
131 | auto allBoolScalarsAndVectors = {s1, v2s1, v3s1, v4s1, v8s1, v16s1}; |
132 | |
133 | auto allIntScalars = {s8, s16, s32, s64}; |
134 | |
135 | auto allFloatScalars = {s16, s32, s64}; |
136 | |
137 | auto allFloatScalarsAndVectors = { |
138 | s16, s32, s64, v2s16, v2s32, v2s64, v3s16, v3s32, v3s64, |
139 | v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64}; |
140 | |
141 | auto allFloatAndIntScalars = allIntScalars; |
142 | |
143 | auto allPtrs = {p0, p1, p2, p3, p4, p5, p6}; |
144 | auto allWritablePtrs = {p0, p1, p3, p4, p5, p6}; |
145 | |
146 | for (auto Opc : TypeFoldingSupportingOpcs) |
147 | getActionDefinitionsBuilder(Opcode: Opc).custom(); |
148 | |
149 | getActionDefinitionsBuilder(Opcode: G_GLOBAL_VALUE).alwaysLegal(); |
150 | |
151 | // TODO: add proper rules for vectors legalization. |
152 | getActionDefinitionsBuilder( |
153 | Opcodes: {G_BUILD_VECTOR, G_SHUFFLE_VECTOR, G_SPLAT_VECTOR}) |
154 | .alwaysLegal(); |
155 | |
156 | // Vector Reduction Operations |
157 | getActionDefinitionsBuilder( |
158 | Opcodes: {G_VECREDUCE_SMIN, G_VECREDUCE_SMAX, G_VECREDUCE_UMIN, G_VECREDUCE_UMAX, |
159 | G_VECREDUCE_ADD, G_VECREDUCE_MUL, G_VECREDUCE_FMUL, G_VECREDUCE_FMIN, |
160 | G_VECREDUCE_FMAX, G_VECREDUCE_FMINIMUM, G_VECREDUCE_FMAXIMUM, |
161 | G_VECREDUCE_OR, G_VECREDUCE_AND, G_VECREDUCE_XOR}) |
162 | .legalFor(Types: allVectors) |
163 | .scalarize(TypeIdx: 1) |
164 | .lower(); |
165 | |
166 | getActionDefinitionsBuilder(Opcodes: {G_VECREDUCE_SEQ_FADD, G_VECREDUCE_SEQ_FMUL}) |
167 | .scalarize(TypeIdx: 2) |
168 | .lower(); |
169 | |
170 | // Merge/Unmerge |
171 | // TODO: add proper legalization rules. |
172 | getActionDefinitionsBuilder(Opcode: G_UNMERGE_VALUES).alwaysLegal(); |
173 | |
174 | getActionDefinitionsBuilder(Opcodes: {G_MEMCPY, G_MEMMOVE}) |
175 | .legalIf(Predicate: all(P0: typeInSet(TypeIdx: 0, TypesInit: allWritablePtrs), P1: typeInSet(TypeIdx: 1, TypesInit: allPtrs))); |
176 | |
177 | getActionDefinitionsBuilder(Opcode: G_MEMSET).legalIf( |
178 | Predicate: all(P0: typeInSet(TypeIdx: 0, TypesInit: allWritablePtrs), P1: typeInSet(TypeIdx: 1, TypesInit: allIntScalars))); |
179 | |
180 | getActionDefinitionsBuilder(Opcode: G_ADDRSPACE_CAST) |
181 | .legalForCartesianProduct(Types0: allPtrs, Types1: allPtrs); |
182 | |
183 | getActionDefinitionsBuilder(Opcodes: {G_LOAD, G_STORE}).legalIf(Predicate: typeInSet(TypeIdx: 1, TypesInit: allPtrs)); |
184 | |
185 | getActionDefinitionsBuilder(Opcode: G_BITREVERSE).legalFor(Types: allFloatScalarsAndVectors); |
186 | |
187 | getActionDefinitionsBuilder(Opcode: G_FMA).legalFor(Types: allFloatScalarsAndVectors); |
188 | |
189 | getActionDefinitionsBuilder(Opcodes: {G_FPTOSI, G_FPTOUI}) |
190 | .legalForCartesianProduct(Types0: allIntScalarsAndVectors, |
191 | Types1: allFloatScalarsAndVectors); |
192 | |
193 | getActionDefinitionsBuilder(Opcodes: {G_SITOFP, G_UITOFP}) |
194 | .legalForCartesianProduct(Types0: allFloatScalarsAndVectors, |
195 | Types1: allScalarsAndVectors); |
196 | |
197 | getActionDefinitionsBuilder(Opcodes: {G_SMIN, G_SMAX, G_UMIN, G_UMAX, G_ABS}) |
198 | .legalFor(Types: allIntScalarsAndVectors); |
199 | |
200 | getActionDefinitionsBuilder(Opcode: G_CTPOP).legalForCartesianProduct( |
201 | Types0: allIntScalarsAndVectors, Types1: allIntScalarsAndVectors); |
202 | |
203 | getActionDefinitionsBuilder(Opcode: G_PHI).legalFor(Types: allPtrsScalarsAndVectors); |
204 | |
205 | getActionDefinitionsBuilder(Opcode: G_BITCAST).legalIf( |
206 | Predicate: all(P0: typeInSet(TypeIdx: 0, TypesInit: allPtrsScalarsAndVectors), |
207 | P1: typeInSet(TypeIdx: 1, TypesInit: allPtrsScalarsAndVectors))); |
208 | |
209 | getActionDefinitionsBuilder(Opcodes: {G_IMPLICIT_DEF, G_FREEZE}).alwaysLegal(); |
210 | |
211 | getActionDefinitionsBuilder(Opcodes: {G_STACKSAVE, G_STACKRESTORE}).alwaysLegal(); |
212 | |
213 | getActionDefinitionsBuilder(Opcode: G_INTTOPTR) |
214 | .legalForCartesianProduct(Types0: allPtrs, Types1: allIntScalars); |
215 | getActionDefinitionsBuilder(Opcode: G_PTRTOINT) |
216 | .legalForCartesianProduct(Types0: allIntScalars, Types1: allPtrs); |
217 | getActionDefinitionsBuilder(Opcode: G_PTR_ADD).legalForCartesianProduct( |
218 | Types0: allPtrs, Types1: allIntScalars); |
219 | |
220 | // ST.canDirectlyComparePointers() for pointer args is supported in |
221 | // legalizeCustom(). |
222 | getActionDefinitionsBuilder(Opcode: G_ICMP).customIf( |
223 | Predicate: all(P0: typeInSet(TypeIdx: 0, TypesInit: allBoolScalarsAndVectors), |
224 | P1: typeInSet(TypeIdx: 1, TypesInit: allPtrsScalarsAndVectors))); |
225 | |
226 | getActionDefinitionsBuilder(Opcode: G_FCMP).legalIf( |
227 | Predicate: all(P0: typeInSet(TypeIdx: 0, TypesInit: allBoolScalarsAndVectors), |
228 | P1: typeInSet(TypeIdx: 1, TypesInit: allFloatScalarsAndVectors))); |
229 | |
230 | getActionDefinitionsBuilder(Opcodes: {G_ATOMICRMW_OR, G_ATOMICRMW_ADD, G_ATOMICRMW_AND, |
231 | G_ATOMICRMW_MAX, G_ATOMICRMW_MIN, |
232 | G_ATOMICRMW_SUB, G_ATOMICRMW_XOR, |
233 | G_ATOMICRMW_UMAX, G_ATOMICRMW_UMIN}) |
234 | .legalForCartesianProduct(Types0: allIntScalars, Types1: allWritablePtrs); |
235 | |
236 | getActionDefinitionsBuilder( |
237 | Opcodes: {G_ATOMICRMW_FADD, G_ATOMICRMW_FSUB, G_ATOMICRMW_FMIN, G_ATOMICRMW_FMAX}) |
238 | .legalForCartesianProduct(Types0: allFloatScalars, Types1: allWritablePtrs); |
239 | |
240 | getActionDefinitionsBuilder(Opcode: G_ATOMICRMW_XCHG) |
241 | .legalForCartesianProduct(Types0: allFloatAndIntScalars, Types1: allWritablePtrs); |
242 | |
243 | getActionDefinitionsBuilder(Opcode: G_ATOMIC_CMPXCHG_WITH_SUCCESS).lower(); |
244 | // TODO: add proper legalization rules. |
245 | getActionDefinitionsBuilder(Opcode: G_ATOMIC_CMPXCHG).alwaysLegal(); |
246 | |
247 | getActionDefinitionsBuilder(Opcodes: {G_UADDO, G_USUBO, G_SMULO, G_UMULO}) |
248 | .alwaysLegal(); |
249 | |
250 | // Extensions. |
251 | getActionDefinitionsBuilder(Opcodes: {G_TRUNC, G_ZEXT, G_SEXT, G_ANYEXT}) |
252 | .legalForCartesianProduct(Types: allScalarsAndVectors); |
253 | |
254 | // FP conversions. |
255 | getActionDefinitionsBuilder(Opcodes: {G_FPTRUNC, G_FPEXT}) |
256 | .legalForCartesianProduct(Types: allFloatScalarsAndVectors); |
257 | |
258 | // Pointer-handling. |
259 | getActionDefinitionsBuilder(Opcode: G_FRAME_INDEX).legalFor(Types: {p0}); |
260 | |
261 | // Control-flow. In some cases (e.g. constants) s1 may be promoted to s32. |
262 | getActionDefinitionsBuilder(Opcode: G_BRCOND).legalFor(Types: {s1, s32}); |
263 | |
264 | // TODO: Review the target OpenCL and GLSL Extended Instruction Set specs to |
265 | // tighten these requirements. Many of these math functions are only legal on |
266 | // specific bitwidths, so they are not selectable for |
267 | // allFloatScalarsAndVectors. |
268 | getActionDefinitionsBuilder(Opcodes: {G_FPOW, |
269 | G_FEXP, |
270 | G_FEXP2, |
271 | G_FLOG, |
272 | G_FLOG2, |
273 | G_FLOG10, |
274 | G_FABS, |
275 | G_FMINNUM, |
276 | G_FMAXNUM, |
277 | G_FCEIL, |
278 | G_FCOS, |
279 | G_FSIN, |
280 | G_FSQRT, |
281 | G_FFLOOR, |
282 | G_FRINT, |
283 | G_FNEARBYINT, |
284 | G_INTRINSIC_ROUND, |
285 | G_INTRINSIC_TRUNC, |
286 | G_FMINIMUM, |
287 | G_FMAXIMUM, |
288 | G_INTRINSIC_ROUNDEVEN}) |
289 | .legalFor(Types: allFloatScalarsAndVectors); |
290 | |
291 | getActionDefinitionsBuilder(Opcode: G_FCOPYSIGN) |
292 | .legalForCartesianProduct(Types0: allFloatScalarsAndVectors, |
293 | Types1: allFloatScalarsAndVectors); |
294 | |
295 | getActionDefinitionsBuilder(Opcode: G_FPOWI).legalForCartesianProduct( |
296 | Types0: allFloatScalarsAndVectors, Types1: allIntScalarsAndVectors); |
297 | |
298 | if (ST.canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) { |
299 | getActionDefinitionsBuilder( |
300 | Opcodes: {G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTLZ, G_CTLZ_ZERO_UNDEF}) |
301 | .legalForCartesianProduct(Types0: allIntScalarsAndVectors, |
302 | Types1: allIntScalarsAndVectors); |
303 | |
304 | // Struct return types become a single scalar, so cannot easily legalize. |
305 | getActionDefinitionsBuilder(Opcodes: {G_SMULH, G_UMULH}).alwaysLegal(); |
306 | } |
307 | |
308 | getLegacyLegalizerInfo().computeTables(); |
309 | verify(*ST.getInstrInfo()); |
310 | } |
311 | |
312 | static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpirvType, |
313 | LegalizerHelper &Helper, |
314 | MachineRegisterInfo &MRI, |
315 | SPIRVGlobalRegistry *GR) { |
316 | Register ConvReg = MRI.createGenericVirtualRegister(Ty: ConvTy); |
317 | GR->assignSPIRVTypeToVReg(Type: SpirvType, VReg: ConvReg, MF&: Helper.MIRBuilder.getMF()); |
318 | Helper.MIRBuilder.buildInstr(Opcode: TargetOpcode::G_PTRTOINT) |
319 | .addDef(RegNo: ConvReg) |
320 | .addUse(RegNo: Reg); |
321 | return ConvReg; |
322 | } |
323 | |
324 | bool SPIRVLegalizerInfo::legalizeCustom( |
325 | LegalizerHelper &Helper, MachineInstr &MI, |
326 | LostDebugLocObserver &LocObserver) const { |
327 | auto Opc = MI.getOpcode(); |
328 | MachineRegisterInfo &MRI = MI.getMF()->getRegInfo(); |
329 | if (!isTypeFoldingSupported(Opcode: Opc)) { |
330 | assert(Opc == TargetOpcode::G_ICMP); |
331 | assert(GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg())); |
332 | auto &Op0 = MI.getOperand(i: 2); |
333 | auto &Op1 = MI.getOperand(i: 3); |
334 | Register Reg0 = Op0.getReg(); |
335 | Register Reg1 = Op1.getReg(); |
336 | CmpInst::Predicate Cond = |
337 | static_cast<CmpInst::Predicate>(MI.getOperand(i: 1).getPredicate()); |
338 | if ((!ST->canDirectlyComparePointers() || |
339 | (Cond != CmpInst::ICMP_EQ && Cond != CmpInst::ICMP_NE)) && |
340 | MRI.getType(Reg: Reg0).isPointer() && MRI.getType(Reg: Reg1).isPointer()) { |
341 | LLT ConvT = LLT::scalar(SizeInBits: ST->getPointerSize()); |
342 | Type *LLVMTy = IntegerType::get(C&: MI.getMF()->getFunction().getContext(), |
343 | NumBits: ST->getPointerSize()); |
344 | SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(LLVMTy, Helper.MIRBuilder); |
345 | Op0.setReg(convertPtrToInt(Reg: Reg0, ConvTy: ConvT, SpirvType: SpirvTy, Helper, MRI, GR)); |
346 | Op1.setReg(convertPtrToInt(Reg: Reg1, ConvTy: ConvT, SpirvType: SpirvTy, Helper, MRI, GR)); |
347 | } |
348 | return true; |
349 | } |
350 | // TODO: implement legalization for other opcodes. |
351 | return true; |
352 | } |
353 | |