1//===- DXContainer.h - DXContainer file implementation ----------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file declares the DXContainer class, which parses a DXContainer file and
// provides read-only access to its parts (DXIL, hash, PSV info, signatures).
11//
12//
13//===----------------------------------------------------------------------===//
14
15#ifndef LLVM_OBJECT_DXCONTAINER_H
16#define LLVM_OBJECT_DXCONTAINER_H
17
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/BinaryFormat/DXContainer.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/MemoryBufferRef.h"
#include "llvm/TargetParser/Triple.h"
#include <algorithm>
#include <array>
#include <cstring>
#include <optional>
#include <type_traits>
#include <utility>
#include <variant>
26
27namespace llvm {
28namespace object {
29
30namespace detail {
31template <typename T>
32std::enable_if_t<std::is_arithmetic<T>::value, void> swapBytes(T &value) {
33 sys::swapByteOrder(value);
34}
35
36template <typename T>
37std::enable_if_t<std::is_class<T>::value, void> swapBytes(T &value) {
38 value.swapBytes();
39}
40} // namespace detail
41
42// This class provides a view into the underlying resource array. The Resource
43// data is little-endian encoded and may not be properly aligned to read
44// directly from. The dereference operator creates a copy of the data and byte
45// swaps it as appropriate.
46template <typename T> struct ViewArray {
47 StringRef Data;
48 uint32_t Stride = sizeof(T); // size of each element in the list.
49
50 ViewArray() = default;
51 ViewArray(StringRef D, size_t S) : Data(D), Stride(S) {}
52
53 using value_type = T;
54 static constexpr uint32_t MaxStride() {
55 return static_cast<uint32_t>(sizeof(value_type));
56 }
57
58 struct iterator {
59 StringRef Data;
60 uint32_t Stride; // size of each element in the list.
61 const char *Current;
62
63 iterator(const ViewArray &A, const char *C)
64 : Data(A.Data), Stride(A.Stride), Current(C) {}
65 iterator(const iterator &) = default;
66
67 value_type operator*() {
68 // Explicitly zero the structure so that unused fields are zeroed. It is
69 // up to the user to know if the fields are used by verifying the PSV
70 // version.
71 value_type Val;
72 std::memset(s: &Val, c: 0, n: sizeof(value_type));
73 if (Current >= Data.end())
74 return Val;
75 memcpy(dest: static_cast<void *>(&Val), src: Current, n: std::min(a: Stride, b: MaxStride()));
76 if (sys::IsBigEndianHost)
77 detail::swapBytes(Val);
78 return Val;
79 }
80
81 iterator operator++() {
82 if (Current < Data.end())
83 Current += Stride;
84 return *this;
85 }
86
87 iterator operator++(int) {
88 iterator Tmp = *this;
89 ++*this;
90 return Tmp;
91 }
92
93 iterator operator--() {
94 if (Current > Data.begin())
95 Current -= Stride;
96 return *this;
97 }
98
99 iterator operator--(int) {
100 iterator Tmp = *this;
101 --*this;
102 return Tmp;
103 }
104
105 bool operator==(const iterator I) { return I.Current == Current; }
106 bool operator!=(const iterator I) { return !(*this == I); }
107 };
108
109 iterator begin() const { return iterator(*this, Data.begin()); }
110
111 iterator end() const { return iterator(*this, Data.end()); }
112
113 size_t size() const { return Data.size() / Stride; }
114
115 bool isEmpty() const { return Data.empty(); }
116};
117
118namespace DirectX {
119class PSVRuntimeInfo {
120
121 using ResourceArray = ViewArray<dxbc::PSV::v2::ResourceBindInfo>;
122 using SigElementArray = ViewArray<dxbc::PSV::v0::SignatureElement>;
123
124 StringRef Data;
125 uint32_t Size;
126 using InfoStruct =
127 std::variant<std::monostate, dxbc::PSV::v0::RuntimeInfo,
128 dxbc::PSV::v1::RuntimeInfo, dxbc::PSV::v2::RuntimeInfo,
129 dxbc::PSV::v3::RuntimeInfo>;
130 InfoStruct BasicInfo;
131 ResourceArray Resources;
132 StringRef StringTable;
133 SmallVector<uint32_t> SemanticIndexTable;
134 SigElementArray SigInputElements;
135 SigElementArray SigOutputElements;
136 SigElementArray SigPatchOrPrimElements;
137
138 std::array<ViewArray<uint32_t>, 4> OutputVectorMasks;
139 ViewArray<uint32_t> PatchOrPrimMasks;
140 std::array<ViewArray<uint32_t>, 4> InputOutputMap;
141 ViewArray<uint32_t> InputPatchMap;
142 ViewArray<uint32_t> PatchOutputMap;
143
144public:
145 PSVRuntimeInfo(StringRef D) : Data(D), Size(0) {}
146
147 // Parsing depends on the shader kind
148 Error parse(uint16_t ShaderKind);
149
150 uint32_t getSize() const { return Size; }
151 uint32_t getResourceCount() const { return Resources.size(); }
152 ResourceArray getResources() const { return Resources; }
153
154 uint32_t getVersion() const {
155 return Size >= sizeof(dxbc::PSV::v3::RuntimeInfo)
156 ? 3
157 : (Size >= sizeof(dxbc::PSV::v2::RuntimeInfo) ? 2
158 : (Size >= sizeof(dxbc::PSV::v1::RuntimeInfo)) ? 1
159 : 0);
160 }
161
162 uint32_t getResourceStride() const { return Resources.Stride; }
163
164 const InfoStruct &getInfo() const { return BasicInfo; }
165
166 template <typename T> const T *getInfoAs() const {
167 if (const auto *P = std::get_if<dxbc::PSV::v3::RuntimeInfo>(ptr: &BasicInfo))
168 return static_cast<const T *>(P);
169 if (std::is_same<T, dxbc::PSV::v3::RuntimeInfo>::value)
170 return nullptr;
171
172 if (const auto *P = std::get_if<dxbc::PSV::v2::RuntimeInfo>(ptr: &BasicInfo))
173 return static_cast<const T *>(P);
174 if (std::is_same<T, dxbc::PSV::v2::RuntimeInfo>::value)
175 return nullptr;
176
177 if (const auto *P = std::get_if<dxbc::PSV::v1::RuntimeInfo>(ptr: &BasicInfo))
178 return static_cast<const T *>(P);
179 if (std::is_same<T, dxbc::PSV::v1::RuntimeInfo>::value)
180 return nullptr;
181
182 if (const auto *P = std::get_if<dxbc::PSV::v0::RuntimeInfo>(ptr: &BasicInfo))
183 return static_cast<const T *>(P);
184 return nullptr;
185 }
186
187 StringRef getStringTable() const { return StringTable; }
188 ArrayRef<uint32_t> getSemanticIndexTable() const {
189 return SemanticIndexTable;
190 }
191
192 uint8_t getSigInputCount() const;
193 uint8_t getSigOutputCount() const;
194 uint8_t getSigPatchOrPrimCount() const;
195
196 SigElementArray getSigInputElements() const { return SigInputElements; }
197 SigElementArray getSigOutputElements() const { return SigOutputElements; }
198 SigElementArray getSigPatchOrPrimElements() const {
199 return SigPatchOrPrimElements;
200 }
201
202 ViewArray<uint32_t> getOutputVectorMasks(size_t Idx) const {
203 assert(Idx < 4);
204 return OutputVectorMasks[Idx];
205 }
206
207 ViewArray<uint32_t> getPatchOrPrimMasks() const { return PatchOrPrimMasks; }
208
209 ViewArray<uint32_t> getInputOutputMap(size_t Idx) const {
210 assert(Idx < 4);
211 return InputOutputMap[Idx];
212 }
213
214 ViewArray<uint32_t> getInputPatchMap() const { return InputPatchMap; }
215 ViewArray<uint32_t> getPatchOutputMap() const { return PatchOutputMap; }
216
217 uint32_t getSigElementStride() const { return SigInputElements.Stride; }
218
219 bool usesViewID() const {
220 if (const auto *P = getInfoAs<dxbc::PSV::v1::RuntimeInfo>())
221 return P->UsesViewID != 0;
222 return false;
223 }
224
225 uint8_t getInputVectorCount() const {
226 if (const auto *P = getInfoAs<dxbc::PSV::v1::RuntimeInfo>())
227 return P->SigInputVectors;
228 return 0;
229 }
230
231 ArrayRef<uint8_t> getOutputVectorCounts() const {
232 if (const auto *P = getInfoAs<dxbc::PSV::v1::RuntimeInfo>())
233 return ArrayRef<uint8_t>(P->SigOutputVectors);
234 return ArrayRef<uint8_t>();
235 }
236
237 uint8_t getPatchConstOrPrimVectorCount() const {
238 if (const auto *P = getInfoAs<dxbc::PSV::v1::RuntimeInfo>())
239 return P->GeomData.SigPatchConstOrPrimVectors;
240 return 0;
241 }
242};
243
244class Signature {
245 ViewArray<dxbc::ProgramSignatureElement> Parameters;
246 uint32_t StringTableOffset;
247 StringRef StringTable;
248
249public:
250 ViewArray<dxbc::ProgramSignatureElement>::iterator begin() const {
251 return Parameters.begin();
252 }
253
254 ViewArray<dxbc::ProgramSignatureElement>::iterator end() const {
255 return Parameters.end();
256 }
257
258 StringRef getName(uint32_t Offset) const {
259 assert(Offset >= StringTableOffset &&
260 Offset < StringTableOffset + StringTable.size() &&
261 "Offset out of range.");
262 // Name offsets are from the start of the signature data, not from the start
263 // of the string table. The header encodes the start offset of the sting
264 // table, so we convert the offset here.
265 uint32_t TableOffset = Offset - StringTableOffset;
266 return StringTable.slice(Start: TableOffset, End: StringTable.find(C: '\0', From: TableOffset));
267 }
268
269 bool isEmpty() const { return Parameters.isEmpty(); }
270
271 Error initialize(StringRef Part);
272};
273
274} // namespace DirectX
275
276class DXContainer {
277public:
278 using DXILData = std::pair<dxbc::ProgramHeader, const char *>;
279
280private:
281 DXContainer(MemoryBufferRef O);
282
283 MemoryBufferRef Data;
284 dxbc::Header Header;
285 SmallVector<uint32_t, 4> PartOffsets;
286 std::optional<DXILData> DXIL;
287 std::optional<uint64_t> ShaderFeatureFlags;
288 std::optional<dxbc::ShaderHash> Hash;
289 std::optional<DirectX::PSVRuntimeInfo> PSVInfo;
290 DirectX::Signature InputSignature;
291 DirectX::Signature OutputSignature;
292 DirectX::Signature PatchConstantSignature;
293
294 Error parseHeader();
295 Error parsePartOffsets();
296 Error parseDXILHeader(StringRef Part);
297 Error parseShaderFeatureFlags(StringRef Part);
298 Error parseHash(StringRef Part);
299 Error parsePSVInfo(StringRef Part);
300 Error parseSignature(StringRef Part, DirectX::Signature &Array);
301 friend class PartIterator;
302
303public:
304 // The PartIterator is a wrapper around the iterator for the PartOffsets
305 // member of the DXContainer. It contains a refernce to the container, and the
306 // current iterator value, as well as storage for a parsed part header.
307 class PartIterator {
308 const DXContainer &Container;
309 SmallVectorImpl<uint32_t>::const_iterator OffsetIt;
310 struct PartData {
311 dxbc::PartHeader Part;
312 uint32_t Offset;
313 StringRef Data;
314 } IteratorState;
315
316 friend class DXContainer;
317
318 PartIterator(const DXContainer &C,
319 SmallVectorImpl<uint32_t>::const_iterator It)
320 : Container(C), OffsetIt(It) {
321 if (OffsetIt == Container.PartOffsets.end())
322 updateIteratorImpl(Offset: Container.PartOffsets.back());
323 else
324 updateIterator();
325 }
326
327 // Updates the iterator's state data. This results in copying the part
328 // header into the iterator and handling any required byte swapping. This is
329 // called when incrementing or decrementing the iterator.
330 void updateIterator() {
331 if (OffsetIt != Container.PartOffsets.end())
332 updateIteratorImpl(Offset: *OffsetIt);
333 }
334
335 // Implementation for updating the iterator state based on a specified
336 // offest.
337 void updateIteratorImpl(const uint32_t Offset);
338
339 public:
340 PartIterator &operator++() {
341 if (OffsetIt == Container.PartOffsets.end())
342 return *this;
343 ++OffsetIt;
344 updateIterator();
345 return *this;
346 }
347
348 PartIterator operator++(int) {
349 PartIterator Tmp = *this;
350 ++(*this);
351 return Tmp;
352 }
353
354 bool operator==(const PartIterator &RHS) const {
355 return OffsetIt == RHS.OffsetIt;
356 }
357
358 bool operator!=(const PartIterator &RHS) const {
359 return OffsetIt != RHS.OffsetIt;
360 }
361
362 const PartData &operator*() { return IteratorState; }
363 const PartData *operator->() { return &IteratorState; }
364 };
365
366 PartIterator begin() const {
367 return PartIterator(*this, PartOffsets.begin());
368 }
369
370 PartIterator end() const { return PartIterator(*this, PartOffsets.end()); }
371
372 StringRef getData() const { return Data.getBuffer(); }
373 static Expected<DXContainer> create(MemoryBufferRef Object);
374
375 const dxbc::Header &getHeader() const { return Header; }
376
377 const std::optional<DXILData> &getDXIL() const { return DXIL; }
378
379 std::optional<uint64_t> getShaderFeatureFlags() const {
380 return ShaderFeatureFlags;
381 }
382
383 std::optional<dxbc::ShaderHash> getShaderHash() const { return Hash; }
384
385 const std::optional<DirectX::PSVRuntimeInfo> &getPSVInfo() const {
386 return PSVInfo;
387 };
388
389 const DirectX::Signature &getInputSignature() const { return InputSignature; }
390 const DirectX::Signature &getOutputSignature() const {
391 return OutputSignature;
392 }
393 const DirectX::Signature &getPatchConstantSignature() const {
394 return PatchConstantSignature;
395 }
396};
397
398} // namespace object
399} // namespace llvm
400
401#endif // LLVM_OBJECT_DXCONTAINER_H
402

source code of llvm/include/llvm/Object/DXContainer.h