1//===-- SPIRVDuplicatesTracker.h - SPIR-V Duplicates Tracker ----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// General infrastructure for keeping track of the values that according to
10// the SPIR-V binary layout should be global to the whole module.
11//
12//===----------------------------------------------------------------------===//
13
14#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVDUPLICATESTRACKER_H
15#define LLVM_LIB_TARGET_SPIRV_SPIRVDUPLICATESTRACKER_H
16
17#include "MCTargetDesc/SPIRVBaseInfo.h"
18#include "MCTargetDesc/SPIRVMCTargetDesc.h"
19#include "llvm/ADT/DenseMap.h"
20#include "llvm/ADT/MapVector.h"
21#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
22#include "llvm/CodeGen/MachineModuleInfo.h"
23
24#include <type_traits>
25
26namespace llvm {
27namespace SPIRV {
28// NOTE: using MapVector instead of DenseMap because it helps getting
29// everything ordered in a stable manner for a price of extra (NumKeys)*PtrSize
30// memory and expensive removals which do not happen anyway.
31class DTSortableEntry : public MapVector<const MachineFunction *, Register> {
32 SmallVector<DTSortableEntry *, 2> Deps;
33
34 struct FlagsTy {
35 unsigned IsFunc : 1;
36 unsigned IsGV : 1;
37 // NOTE: bit-field default init is a C++20 feature.
38 FlagsTy() : IsFunc(0), IsGV(0) {}
39 };
40 FlagsTy Flags;
41
42public:
43 // Common hoisting utility doesn't support function, because their hoisting
44 // require hoisting of params as well.
45 bool getIsFunc() const { return Flags.IsFunc; }
46 bool getIsGV() const { return Flags.IsGV; }
47 void setIsFunc(bool V) { Flags.IsFunc = V; }
48 void setIsGV(bool V) { Flags.IsGV = V; }
49
50 const SmallVector<DTSortableEntry *, 2> &getDeps() const { return Deps; }
51 void addDep(DTSortableEntry *E) { Deps.push_back(Elt: E); }
52};
53
// Base class for descriptors of special SPIR-V types (images, samplers,
// pipes, pointers, ...). Descriptors are used as map keys *by value* as this
// base type, so only Kind and Hash survive when a derived descriptor is
// stored; a derived class must fold everything that identifies it into Hash.
struct SpecialTypeDescriptor {
  enum SpecialTypeKind {
    STK_Empty = 0, // Used for the DenseMap empty key.
    STK_Image,
    STK_SampledImage,
    STK_Sampler,
    STK_Pipe,
    STK_DeviceEvent,
    STK_Pointer,
    STK_Last = -1 // Used for the DenseMap tombstone key.
  };
  SpecialTypeKind Kind;

  unsigned Hash;

  SpecialTypeDescriptor() = delete;
  // The hash starts out as the kind itself; derived classes refine it.
  SpecialTypeDescriptor(SpecialTypeKind K) : Kind(K), Hash(K) {}

  unsigned getHash() const { return Hash; }

  virtual ~SpecialTypeDescriptor() = default;
};
76
77struct ImageTypeDescriptor : public SpecialTypeDescriptor {
78 union ImageAttrs {
79 struct BitFlags {
80 unsigned Dim : 3;
81 unsigned Depth : 2;
82 unsigned Arrayed : 1;
83 unsigned MS : 1;
84 unsigned Sampled : 2;
85 unsigned ImageFormat : 6;
86 unsigned AQ : 2;
87 } Flags;
88 unsigned Val;
89 };
90
91 ImageTypeDescriptor(const Type *SampledTy, unsigned Dim, unsigned Depth,
92 unsigned Arrayed, unsigned MS, unsigned Sampled,
93 unsigned ImageFormat, unsigned AQ = 0)
94 : SpecialTypeDescriptor(SpecialTypeKind::STK_Image) {
95 ImageAttrs Attrs;
96 Attrs.Val = 0;
97 Attrs.Flags.Dim = Dim;
98 Attrs.Flags.Depth = Depth;
99 Attrs.Flags.Arrayed = Arrayed;
100 Attrs.Flags.MS = MS;
101 Attrs.Flags.Sampled = Sampled;
102 Attrs.Flags.ImageFormat = ImageFormat;
103 Attrs.Flags.AQ = AQ;
104 Hash = (DenseMapInfo<Type *>().getHashValue(PtrVal: SampledTy) & 0xffff) ^
105 ((Attrs.Val << 8) | Kind);
106 }
107
108 static bool classof(const SpecialTypeDescriptor *TD) {
109 return TD->Kind == SpecialTypeKind::STK_Image;
110 }
111};
112
113struct SampledImageTypeDescriptor : public SpecialTypeDescriptor {
114 SampledImageTypeDescriptor(const Type *SampledTy, const MachineInstr *ImageTy)
115 : SpecialTypeDescriptor(SpecialTypeKind::STK_SampledImage) {
116 assert(ImageTy->getOpcode() == SPIRV::OpTypeImage);
117 ImageTypeDescriptor TD(
118 SampledTy, ImageTy->getOperand(i: 2).getImm(),
119 ImageTy->getOperand(i: 3).getImm(), ImageTy->getOperand(i: 4).getImm(),
120 ImageTy->getOperand(i: 5).getImm(), ImageTy->getOperand(i: 6).getImm(),
121 ImageTy->getOperand(i: 7).getImm(), ImageTy->getOperand(i: 8).getImm());
122 Hash = TD.getHash() ^ Kind;
123 }
124
125 static bool classof(const SpecialTypeDescriptor *TD) {
126 return TD->Kind == SpecialTypeKind::STK_SampledImage;
127 }
128};
129
130struct SamplerTypeDescriptor : public SpecialTypeDescriptor {
131 SamplerTypeDescriptor()
132 : SpecialTypeDescriptor(SpecialTypeKind::STK_Sampler) {
133 Hash = Kind;
134 }
135
136 static bool classof(const SpecialTypeDescriptor *TD) {
137 return TD->Kind == SpecialTypeKind::STK_Sampler;
138 }
139};
140
141struct PipeTypeDescriptor : public SpecialTypeDescriptor {
142
143 PipeTypeDescriptor(uint8_t AQ)
144 : SpecialTypeDescriptor(SpecialTypeKind::STK_Pipe) {
145 Hash = (AQ << 8) | Kind;
146 }
147
148 static bool classof(const SpecialTypeDescriptor *TD) {
149 return TD->Kind == SpecialTypeKind::STK_Pipe;
150 }
151};
152
153struct DeviceEventTypeDescriptor : public SpecialTypeDescriptor {
154
155 DeviceEventTypeDescriptor()
156 : SpecialTypeDescriptor(SpecialTypeKind::STK_DeviceEvent) {
157 Hash = Kind;
158 }
159
160 static bool classof(const SpecialTypeDescriptor *TD) {
161 return TD->Kind == SpecialTypeKind::STK_DeviceEvent;
162 }
163};
164
165struct PointerTypeDescriptor : public SpecialTypeDescriptor {
166 const Type *ElementType;
167 unsigned AddressSpace;
168
169 PointerTypeDescriptor() = delete;
170 PointerTypeDescriptor(const Type *ElementType, unsigned AddressSpace)
171 : SpecialTypeDescriptor(SpecialTypeKind::STK_Pointer),
172 ElementType(ElementType), AddressSpace(AddressSpace) {
173 Hash = (DenseMapInfo<Type *>().getHashValue(PtrVal: ElementType) & 0xffff) ^
174 ((AddressSpace << 8) | Kind);
175 }
176
177 static bool classof(const SpecialTypeDescriptor *TD) {
178 return TD->Kind == SpecialTypeKind::STK_Pointer;
179 }
180};
181} // namespace SPIRV
182
183template <> struct DenseMapInfo<SPIRV::SpecialTypeDescriptor> {
184 static inline SPIRV::SpecialTypeDescriptor getEmptyKey() {
185 return SPIRV::SpecialTypeDescriptor(
186 SPIRV::SpecialTypeDescriptor::STK_Empty);
187 }
188 static inline SPIRV::SpecialTypeDescriptor getTombstoneKey() {
189 return SPIRV::SpecialTypeDescriptor(SPIRV::SpecialTypeDescriptor::STK_Last);
190 }
191 static unsigned getHashValue(SPIRV::SpecialTypeDescriptor Val) {
192 return Val.getHash();
193 }
194 static bool isEqual(SPIRV::SpecialTypeDescriptor LHS,
195 SPIRV::SpecialTypeDescriptor RHS) {
196 return getHashValue(Val: LHS) == getHashValue(Val: RHS);
197 }
198};
199
200template <typename KeyTy> class SPIRVDuplicatesTrackerBase {
201public:
202 // NOTE: using MapVector instead of DenseMap helps getting everything ordered
203 // in a stable manner for a price of extra (NumKeys)*PtrSize memory and
204 // expensive removals which don't happen anyway.
205 using StorageTy = MapVector<KeyTy, SPIRV::DTSortableEntry>;
206
207private:
208 StorageTy Storage;
209
210public:
211 void add(KeyTy V, const MachineFunction *MF, Register R) {
212 if (find(V, MF).isValid())
213 return;
214
215 Storage[V][MF] = R;
216 if (std::is_same<Function,
217 typename std::remove_const<
218 typename std::remove_pointer<KeyTy>::type>::type>() ||
219 std::is_same<Argument,
220 typename std::remove_const<
221 typename std::remove_pointer<KeyTy>::type>::type>())
222 Storage[V].setIsFunc(true);
223 if (std::is_same<GlobalVariable,
224 typename std::remove_const<
225 typename std::remove_pointer<KeyTy>::type>::type>())
226 Storage[V].setIsGV(true);
227 }
228
229 Register find(KeyTy V, const MachineFunction *MF) const {
230 auto iter = Storage.find(V);
231 if (iter != Storage.end()) {
232 auto Map = iter->second;
233 auto iter2 = Map.find(MF);
234 if (iter2 != Map.end())
235 return iter2->second;
236 }
237 return Register();
238 }
239
240 const StorageTy &getAllUses() const { return Storage; }
241
242private:
243 StorageTy &getAllUses() { return Storage; }
244
245 // The friend class needs to have access to the internal storage
246 // to be able to build dependency graph, can't declare only one
247 // function a 'friend' due to the incomplete declaration at this point
248 // and mutual dependency problems.
249 friend class SPIRVGeneralDuplicatesTracker;
250};
251
252template <typename T>
253class SPIRVDuplicatesTracker : public SPIRVDuplicatesTrackerBase<const T *> {};
254
255template <>
256class SPIRVDuplicatesTracker<SPIRV::SpecialTypeDescriptor>
257 : public SPIRVDuplicatesTrackerBase<SPIRV::SpecialTypeDescriptor> {};
258
259class SPIRVGeneralDuplicatesTracker {
260 SPIRVDuplicatesTracker<Type> TT;
261 SPIRVDuplicatesTracker<Constant> CT;
262 SPIRVDuplicatesTracker<GlobalVariable> GT;
263 SPIRVDuplicatesTracker<Function> FT;
264 SPIRVDuplicatesTracker<Argument> AT;
265 SPIRVDuplicatesTracker<MachineInstr> MT;
266 SPIRVDuplicatesTracker<SPIRV::SpecialTypeDescriptor> ST;
267
268 // NOTE: using MOs instead of regs to get rid of MF dependency to be able
269 // to use flat data structure.
270 // NOTE: replacing DenseMap with MapVector doesn't affect overall correctness
271 // but makes LITs more stable, should prefer DenseMap still due to
272 // significant perf difference.
273 using SPIRVReg2EntryTy =
274 MapVector<MachineOperand *, SPIRV::DTSortableEntry *>;
275
276 template <typename T>
277 void prebuildReg2Entry(SPIRVDuplicatesTracker<T> &DT,
278 SPIRVReg2EntryTy &Reg2Entry);
279
280public:
281 void buildDepsGraph(std::vector<SPIRV::DTSortableEntry *> &Graph,
282 MachineModuleInfo *MMI);
283
284 void add(const Type *Ty, const MachineFunction *MF, Register R) {
285 TT.add(V: Ty, MF, R);
286 }
287
288 void add(const Type *PointerElementType, unsigned AddressSpace,
289 const MachineFunction *MF, Register R) {
290 ST.add(V: SPIRV::PointerTypeDescriptor(PointerElementType, AddressSpace), MF,
291 R);
292 }
293
294 void add(const Constant *C, const MachineFunction *MF, Register R) {
295 CT.add(V: C, MF, R);
296 }
297
298 void add(const GlobalVariable *GV, const MachineFunction *MF, Register R) {
299 GT.add(V: GV, MF, R);
300 }
301
302 void add(const Function *F, const MachineFunction *MF, Register R) {
303 FT.add(V: F, MF, R);
304 }
305
306 void add(const Argument *Arg, const MachineFunction *MF, Register R) {
307 AT.add(V: Arg, MF, R);
308 }
309
310 void add(const MachineInstr *MI, const MachineFunction *MF, Register R) {
311 MT.add(V: MI, MF, R);
312 }
313
314 void add(const SPIRV::SpecialTypeDescriptor &TD, const MachineFunction *MF,
315 Register R) {
316 ST.add(V: TD, MF, R);
317 }
318
319 Register find(const Type *Ty, const MachineFunction *MF) {
320 return TT.find(V: const_cast<Type *>(Ty), MF);
321 }
322
323 Register find(const Type *PointerElementType, unsigned AddressSpace,
324 const MachineFunction *MF) {
325 return ST.find(
326 V: SPIRV::PointerTypeDescriptor(PointerElementType, AddressSpace), MF);
327 }
328
329 Register find(const Constant *C, const MachineFunction *MF) {
330 return CT.find(V: const_cast<Constant *>(C), MF);
331 }
332
333 Register find(const GlobalVariable *GV, const MachineFunction *MF) {
334 return GT.find(V: const_cast<GlobalVariable *>(GV), MF);
335 }
336
337 Register find(const Function *F, const MachineFunction *MF) {
338 return FT.find(V: const_cast<Function *>(F), MF);
339 }
340
341 Register find(const Argument *Arg, const MachineFunction *MF) {
342 return AT.find(V: const_cast<Argument *>(Arg), MF);
343 }
344
345 Register find(const MachineInstr *MI, const MachineFunction *MF) {
346 return MT.find(V: const_cast<MachineInstr *>(MI), MF);
347 }
348
349 Register find(const SPIRV::SpecialTypeDescriptor &TD,
350 const MachineFunction *MF) {
351 return ST.find(V: TD, MF);
352 }
353
354 const SPIRVDuplicatesTracker<Type> *getTypes() { return &TT; }
355};
356} // namespace llvm
357#endif // LLVM_LIB_TARGET_SPIRV_SPIRVDUPLICATESTRACKER_H
358

source code of llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h