1 | //===- DXContainer.h - DXContainer file implementation ----------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This file declares the DXContainer class, which provides a read-only view of
// DXContainer files, along with helpers for reading their PSV runtime info and
// signature parts.
11 | // |
12 | // |
13 | //===----------------------------------------------------------------------===// |
14 | |
15 | #ifndef LLVM_OBJECT_DXCONTAINER_H |
16 | #define LLVM_OBJECT_DXCONTAINER_H |
17 | |
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/BinaryFormat/DXContainer.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/MemoryBufferRef.h"
#include "llvm/TargetParser/Triple.h"
#include <algorithm>
#include <array>
#include <cstring>
#include <optional>
#include <type_traits>
#include <variant>
26 | |
27 | namespace llvm { |
28 | namespace object { |
29 | |
30 | namespace detail { |
31 | template <typename T> |
32 | std::enable_if_t<std::is_arithmetic<T>::value, void> swapBytes(T &value) { |
33 | sys::swapByteOrder(value); |
34 | } |
35 | |
36 | template <typename T> |
37 | std::enable_if_t<std::is_class<T>::value, void> swapBytes(T &value) { |
38 | value.swapBytes(); |
39 | } |
40 | } // namespace detail |
41 | |
42 | // This class provides a view into the underlying resource array. The Resource |
43 | // data is little-endian encoded and may not be properly aligned to read |
44 | // directly from. The dereference operator creates a copy of the data and byte |
45 | // swaps it as appropriate. |
46 | template <typename T> struct ViewArray { |
47 | StringRef Data; |
48 | uint32_t Stride = sizeof(T); // size of each element in the list. |
49 | |
50 | ViewArray() = default; |
51 | ViewArray(StringRef D, size_t S) : Data(D), Stride(S) {} |
52 | |
53 | using value_type = T; |
54 | static constexpr uint32_t MaxStride() { |
55 | return static_cast<uint32_t>(sizeof(value_type)); |
56 | } |
57 | |
58 | struct iterator { |
59 | StringRef Data; |
60 | uint32_t Stride; // size of each element in the list. |
61 | const char *Current; |
62 | |
63 | iterator(const ViewArray &A, const char *C) |
64 | : Data(A.Data), Stride(A.Stride), Current(C) {} |
65 | iterator(const iterator &) = default; |
66 | |
67 | value_type operator*() { |
68 | // Explicitly zero the structure so that unused fields are zeroed. It is |
69 | // up to the user to know if the fields are used by verifying the PSV |
70 | // version. |
71 | value_type Val; |
72 | std::memset(s: &Val, c: 0, n: sizeof(value_type)); |
73 | if (Current >= Data.end()) |
74 | return Val; |
75 | memcpy(dest: static_cast<void *>(&Val), src: Current, n: std::min(a: Stride, b: MaxStride())); |
76 | if (sys::IsBigEndianHost) |
77 | detail::swapBytes(Val); |
78 | return Val; |
79 | } |
80 | |
81 | iterator operator++() { |
82 | if (Current < Data.end()) |
83 | Current += Stride; |
84 | return *this; |
85 | } |
86 | |
87 | iterator operator++(int) { |
88 | iterator Tmp = *this; |
89 | ++*this; |
90 | return Tmp; |
91 | } |
92 | |
93 | iterator operator--() { |
94 | if (Current > Data.begin()) |
95 | Current -= Stride; |
96 | return *this; |
97 | } |
98 | |
99 | iterator operator--(int) { |
100 | iterator Tmp = *this; |
101 | --*this; |
102 | return Tmp; |
103 | } |
104 | |
105 | bool operator==(const iterator I) { return I.Current == Current; } |
106 | bool operator!=(const iterator I) { return !(*this == I); } |
107 | }; |
108 | |
109 | iterator begin() const { return iterator(*this, Data.begin()); } |
110 | |
111 | iterator end() const { return iterator(*this, Data.end()); } |
112 | |
113 | size_t size() const { return Data.size() / Stride; } |
114 | |
115 | bool isEmpty() const { return Data.empty(); } |
116 | }; |
117 | |
118 | namespace DirectX { |
119 | class PSVRuntimeInfo { |
120 | |
121 | using ResourceArray = ViewArray<dxbc::PSV::v2::ResourceBindInfo>; |
122 | using SigElementArray = ViewArray<dxbc::PSV::v0::SignatureElement>; |
123 | |
124 | StringRef Data; |
125 | uint32_t Size; |
126 | using InfoStruct = |
127 | std::variant<std::monostate, dxbc::PSV::v0::RuntimeInfo, |
128 | dxbc::PSV::v1::RuntimeInfo, dxbc::PSV::v2::RuntimeInfo, |
129 | dxbc::PSV::v3::RuntimeInfo>; |
130 | InfoStruct BasicInfo; |
131 | ResourceArray Resources; |
132 | StringRef StringTable; |
133 | SmallVector<uint32_t> SemanticIndexTable; |
134 | SigElementArray SigInputElements; |
135 | SigElementArray SigOutputElements; |
136 | SigElementArray SigPatchOrPrimElements; |
137 | |
138 | std::array<ViewArray<uint32_t>, 4> OutputVectorMasks; |
139 | ViewArray<uint32_t> PatchOrPrimMasks; |
140 | std::array<ViewArray<uint32_t>, 4> InputOutputMap; |
141 | ViewArray<uint32_t> InputPatchMap; |
142 | ViewArray<uint32_t> PatchOutputMap; |
143 | |
144 | public: |
145 | PSVRuntimeInfo(StringRef D) : Data(D), Size(0) {} |
146 | |
147 | // Parsing depends on the shader kind |
148 | Error parse(uint16_t ShaderKind); |
149 | |
150 | uint32_t getSize() const { return Size; } |
151 | uint32_t getResourceCount() const { return Resources.size(); } |
152 | ResourceArray getResources() const { return Resources; } |
153 | |
154 | uint32_t getVersion() const { |
155 | return Size >= sizeof(dxbc::PSV::v3::RuntimeInfo) |
156 | ? 3 |
157 | : (Size >= sizeof(dxbc::PSV::v2::RuntimeInfo) ? 2 |
158 | : (Size >= sizeof(dxbc::PSV::v1::RuntimeInfo)) ? 1 |
159 | : 0); |
160 | } |
161 | |
162 | uint32_t getResourceStride() const { return Resources.Stride; } |
163 | |
164 | const InfoStruct &getInfo() const { return BasicInfo; } |
165 | |
166 | template <typename T> const T *getInfoAs() const { |
167 | if (const auto *P = std::get_if<dxbc::PSV::v3::RuntimeInfo>(ptr: &BasicInfo)) |
168 | return static_cast<const T *>(P); |
169 | if (std::is_same<T, dxbc::PSV::v3::RuntimeInfo>::value) |
170 | return nullptr; |
171 | |
172 | if (const auto *P = std::get_if<dxbc::PSV::v2::RuntimeInfo>(ptr: &BasicInfo)) |
173 | return static_cast<const T *>(P); |
174 | if (std::is_same<T, dxbc::PSV::v2::RuntimeInfo>::value) |
175 | return nullptr; |
176 | |
177 | if (const auto *P = std::get_if<dxbc::PSV::v1::RuntimeInfo>(ptr: &BasicInfo)) |
178 | return static_cast<const T *>(P); |
179 | if (std::is_same<T, dxbc::PSV::v1::RuntimeInfo>::value) |
180 | return nullptr; |
181 | |
182 | if (const auto *P = std::get_if<dxbc::PSV::v0::RuntimeInfo>(ptr: &BasicInfo)) |
183 | return static_cast<const T *>(P); |
184 | return nullptr; |
185 | } |
186 | |
187 | StringRef getStringTable() const { return StringTable; } |
188 | ArrayRef<uint32_t> getSemanticIndexTable() const { |
189 | return SemanticIndexTable; |
190 | } |
191 | |
192 | uint8_t getSigInputCount() const; |
193 | uint8_t getSigOutputCount() const; |
194 | uint8_t getSigPatchOrPrimCount() const; |
195 | |
196 | SigElementArray getSigInputElements() const { return SigInputElements; } |
197 | SigElementArray getSigOutputElements() const { return SigOutputElements; } |
198 | SigElementArray getSigPatchOrPrimElements() const { |
199 | return SigPatchOrPrimElements; |
200 | } |
201 | |
202 | ViewArray<uint32_t> getOutputVectorMasks(size_t Idx) const { |
203 | assert(Idx < 4); |
204 | return OutputVectorMasks[Idx]; |
205 | } |
206 | |
207 | ViewArray<uint32_t> getPatchOrPrimMasks() const { return PatchOrPrimMasks; } |
208 | |
209 | ViewArray<uint32_t> getInputOutputMap(size_t Idx) const { |
210 | assert(Idx < 4); |
211 | return InputOutputMap[Idx]; |
212 | } |
213 | |
214 | ViewArray<uint32_t> getInputPatchMap() const { return InputPatchMap; } |
215 | ViewArray<uint32_t> getPatchOutputMap() const { return PatchOutputMap; } |
216 | |
217 | uint32_t getSigElementStride() const { return SigInputElements.Stride; } |
218 | |
219 | bool usesViewID() const { |
220 | if (const auto *P = getInfoAs<dxbc::PSV::v1::RuntimeInfo>()) |
221 | return P->UsesViewID != 0; |
222 | return false; |
223 | } |
224 | |
225 | uint8_t getInputVectorCount() const { |
226 | if (const auto *P = getInfoAs<dxbc::PSV::v1::RuntimeInfo>()) |
227 | return P->SigInputVectors; |
228 | return 0; |
229 | } |
230 | |
231 | ArrayRef<uint8_t> getOutputVectorCounts() const { |
232 | if (const auto *P = getInfoAs<dxbc::PSV::v1::RuntimeInfo>()) |
233 | return ArrayRef<uint8_t>(P->SigOutputVectors); |
234 | return ArrayRef<uint8_t>(); |
235 | } |
236 | |
237 | uint8_t getPatchConstOrPrimVectorCount() const { |
238 | if (const auto *P = getInfoAs<dxbc::PSV::v1::RuntimeInfo>()) |
239 | return P->GeomData.SigPatchConstOrPrimVectors; |
240 | return 0; |
241 | } |
242 | }; |
243 | |
244 | class Signature { |
245 | ViewArray<dxbc::ProgramSignatureElement> Parameters; |
246 | uint32_t StringTableOffset; |
247 | StringRef StringTable; |
248 | |
249 | public: |
250 | ViewArray<dxbc::ProgramSignatureElement>::iterator begin() const { |
251 | return Parameters.begin(); |
252 | } |
253 | |
254 | ViewArray<dxbc::ProgramSignatureElement>::iterator end() const { |
255 | return Parameters.end(); |
256 | } |
257 | |
258 | StringRef getName(uint32_t Offset) const { |
259 | assert(Offset >= StringTableOffset && |
260 | Offset < StringTableOffset + StringTable.size() && |
261 | "Offset out of range." ); |
262 | // Name offsets are from the start of the signature data, not from the start |
263 | // of the string table. The header encodes the start offset of the sting |
264 | // table, so we convert the offset here. |
265 | uint32_t TableOffset = Offset - StringTableOffset; |
266 | return StringTable.slice(Start: TableOffset, End: StringTable.find(C: '\0', From: TableOffset)); |
267 | } |
268 | |
269 | bool isEmpty() const { return Parameters.isEmpty(); } |
270 | |
271 | Error initialize(StringRef Part); |
272 | }; |
273 | |
274 | } // namespace DirectX |
275 | |
276 | class DXContainer { |
277 | public: |
278 | using DXILData = std::pair<dxbc::ProgramHeader, const char *>; |
279 | |
280 | private: |
281 | DXContainer(MemoryBufferRef O); |
282 | |
283 | MemoryBufferRef Data; |
284 | dxbc::Header ; |
285 | SmallVector<uint32_t, 4> PartOffsets; |
286 | std::optional<DXILData> DXIL; |
287 | std::optional<uint64_t> ShaderFeatureFlags; |
288 | std::optional<dxbc::ShaderHash> Hash; |
289 | std::optional<DirectX::PSVRuntimeInfo> PSVInfo; |
290 | DirectX::Signature InputSignature; |
291 | DirectX::Signature OutputSignature; |
292 | DirectX::Signature PatchConstantSignature; |
293 | |
294 | Error (); |
295 | Error parsePartOffsets(); |
296 | Error (StringRef Part); |
297 | Error parseShaderFeatureFlags(StringRef Part); |
298 | Error parseHash(StringRef Part); |
299 | Error parsePSVInfo(StringRef Part); |
300 | Error parseSignature(StringRef Part, DirectX::Signature &Array); |
301 | friend class PartIterator; |
302 | |
303 | public: |
304 | // The PartIterator is a wrapper around the iterator for the PartOffsets |
305 | // member of the DXContainer. It contains a refernce to the container, and the |
306 | // current iterator value, as well as storage for a parsed part header. |
307 | class PartIterator { |
308 | const DXContainer &Container; |
309 | SmallVectorImpl<uint32_t>::const_iterator OffsetIt; |
310 | struct PartData { |
311 | dxbc::PartHeader Part; |
312 | uint32_t Offset; |
313 | StringRef Data; |
314 | } IteratorState; |
315 | |
316 | friend class DXContainer; |
317 | |
318 | PartIterator(const DXContainer &C, |
319 | SmallVectorImpl<uint32_t>::const_iterator It) |
320 | : Container(C), OffsetIt(It) { |
321 | if (OffsetIt == Container.PartOffsets.end()) |
322 | updateIteratorImpl(Offset: Container.PartOffsets.back()); |
323 | else |
324 | updateIterator(); |
325 | } |
326 | |
327 | // Updates the iterator's state data. This results in copying the part |
328 | // header into the iterator and handling any required byte swapping. This is |
329 | // called when incrementing or decrementing the iterator. |
330 | void updateIterator() { |
331 | if (OffsetIt != Container.PartOffsets.end()) |
332 | updateIteratorImpl(Offset: *OffsetIt); |
333 | } |
334 | |
335 | // Implementation for updating the iterator state based on a specified |
336 | // offest. |
337 | void updateIteratorImpl(const uint32_t Offset); |
338 | |
339 | public: |
340 | PartIterator &operator++() { |
341 | if (OffsetIt == Container.PartOffsets.end()) |
342 | return *this; |
343 | ++OffsetIt; |
344 | updateIterator(); |
345 | return *this; |
346 | } |
347 | |
348 | PartIterator operator++(int) { |
349 | PartIterator Tmp = *this; |
350 | ++(*this); |
351 | return Tmp; |
352 | } |
353 | |
354 | bool operator==(const PartIterator &RHS) const { |
355 | return OffsetIt == RHS.OffsetIt; |
356 | } |
357 | |
358 | bool operator!=(const PartIterator &RHS) const { |
359 | return OffsetIt != RHS.OffsetIt; |
360 | } |
361 | |
362 | const PartData &operator*() { return IteratorState; } |
363 | const PartData *operator->() { return &IteratorState; } |
364 | }; |
365 | |
366 | PartIterator begin() const { |
367 | return PartIterator(*this, PartOffsets.begin()); |
368 | } |
369 | |
370 | PartIterator end() const { return PartIterator(*this, PartOffsets.end()); } |
371 | |
372 | StringRef getData() const { return Data.getBuffer(); } |
373 | static Expected<DXContainer> create(MemoryBufferRef Object); |
374 | |
375 | const dxbc::Header &() const { return Header; } |
376 | |
377 | const std::optional<DXILData> &getDXIL() const { return DXIL; } |
378 | |
379 | std::optional<uint64_t> getShaderFeatureFlags() const { |
380 | return ShaderFeatureFlags; |
381 | } |
382 | |
383 | std::optional<dxbc::ShaderHash> getShaderHash() const { return Hash; } |
384 | |
385 | const std::optional<DirectX::PSVRuntimeInfo> &getPSVInfo() const { |
386 | return PSVInfo; |
387 | }; |
388 | |
389 | const DirectX::Signature &getInputSignature() const { return InputSignature; } |
390 | const DirectX::Signature &getOutputSignature() const { |
391 | return OutputSignature; |
392 | } |
393 | const DirectX::Signature &getPatchConstantSignature() const { |
394 | return PatchConstantSignature; |
395 | } |
396 | }; |
397 | |
398 | } // namespace object |
399 | } // namespace llvm |
400 | |
401 | #endif // LLVM_OBJECT_DXCONTAINER_H |
402 | |