1 | //===----- CGHLSLRuntime.cpp - Interface to HLSL Runtimes -----------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This provides an abstract class for HLSL code generation. Concrete |
10 | // subclasses of this implement code generation for specific HLSL |
11 | // runtime libraries. |
12 | // |
13 | //===----------------------------------------------------------------------===// |
14 | |
15 | #include "CGHLSLRuntime.h" |
16 | #include "CGDebugInfo.h" |
17 | #include "CodeGenModule.h" |
18 | #include "clang/AST/Decl.h" |
19 | #include "clang/Basic/TargetOptions.h" |
20 | #include "llvm/IR/Metadata.h" |
21 | #include "llvm/IR/Module.h" |
22 | #include "llvm/Support/FormatVariadic.h" |
23 | |
24 | using namespace clang; |
25 | using namespace CodeGen; |
26 | using namespace clang::hlsl; |
27 | using namespace llvm; |
28 | |
29 | namespace { |
30 | |
31 | void addDxilValVersion(StringRef ValVersionStr, llvm::Module &M) { |
32 | // The validation of ValVersionStr is done at HLSLToolChain::TranslateArgs. |
33 | // Assume ValVersionStr is legal here. |
34 | VersionTuple Version; |
35 | if (Version.tryParse(string: ValVersionStr) || Version.getBuild() || |
36 | Version.getSubminor() || !Version.getMinor()) { |
37 | return; |
38 | } |
39 | |
40 | uint64_t Major = Version.getMajor(); |
41 | uint64_t Minor = *Version.getMinor(); |
42 | |
43 | auto &Ctx = M.getContext(); |
44 | IRBuilder<> B(M.getContext()); |
45 | MDNode *Val = MDNode::get(Context&: Ctx, MDs: {ConstantAsMetadata::get(C: B.getInt32(C: Major)), |
46 | ConstantAsMetadata::get(C: B.getInt32(C: Minor))}); |
47 | StringRef DXILValKey = "dx.valver" ; |
48 | auto *DXILValMD = M.getOrInsertNamedMetadata(Name: DXILValKey); |
49 | DXILValMD->addOperand(M: Val); |
50 | } |
51 | void addDisableOptimizations(llvm::Module &M) { |
52 | StringRef Key = "dx.disable_optimizations" ; |
53 | M.addModuleFlag(Behavior: llvm::Module::ModFlagBehavior::Override, Key, Val: 1); |
54 | } |
55 | // cbuffer will be translated into global variable in special address space. |
56 | // If translate into C, |
57 | // cbuffer A { |
58 | // float a; |
59 | // float b; |
60 | // } |
61 | // float foo() { return a + b; } |
62 | // |
63 | // will be translated into |
64 | // |
65 | // struct A { |
66 | // float a; |
67 | // float b; |
68 | // } cbuffer_A __attribute__((address_space(4))); |
69 | // float foo() { return cbuffer_A.a + cbuffer_A.b; } |
70 | // |
71 | // layoutBuffer will create the struct A type. |
72 | // replaceBuffer will replace use of global variable a and b with cbuffer_A.a |
73 | // and cbuffer_A.b. |
74 | // |
75 | void layoutBuffer(CGHLSLRuntime::Buffer &Buf, const DataLayout &DL) { |
76 | if (Buf.Constants.empty()) |
77 | return; |
78 | |
79 | std::vector<llvm::Type *> EltTys; |
80 | for (auto &Const : Buf.Constants) { |
81 | GlobalVariable *GV = Const.first; |
82 | Const.second = EltTys.size(); |
83 | llvm::Type *Ty = GV->getValueType(); |
84 | EltTys.emplace_back(args&: Ty); |
85 | } |
86 | Buf.LayoutStruct = llvm::StructType::get(Context&: EltTys[0]->getContext(), Elements: EltTys); |
87 | } |
88 | |
89 | GlobalVariable *replaceBuffer(CGHLSLRuntime::Buffer &Buf) { |
90 | // Create global variable for CB. |
91 | GlobalVariable *CBGV = new GlobalVariable( |
92 | Buf.LayoutStruct, /*isConstant*/ true, |
93 | GlobalValue::LinkageTypes::ExternalLinkage, nullptr, |
94 | llvm::formatv(Fmt: "{0}{1}" , Vals&: Buf.Name, Vals: Buf.IsCBuffer ? ".cb." : ".tb." ), |
95 | GlobalValue::NotThreadLocal); |
96 | |
97 | IRBuilder<> B(CBGV->getContext()); |
98 | Value *ZeroIdx = B.getInt32(C: 0); |
99 | // Replace Const use with CB use. |
100 | for (auto &[GV, Offset] : Buf.Constants) { |
101 | Value *GEP = |
102 | B.CreateGEP(Ty: Buf.LayoutStruct, Ptr: CBGV, IdxList: {ZeroIdx, B.getInt32(C: Offset)}); |
103 | |
104 | assert(Buf.LayoutStruct->getElementType(Offset) == GV->getValueType() && |
105 | "constant type mismatch" ); |
106 | |
107 | // Replace. |
108 | GV->replaceAllUsesWith(V: GEP); |
109 | // Erase GV. |
110 | GV->removeDeadConstantUsers(); |
111 | GV->eraseFromParent(); |
112 | } |
113 | return CBGV; |
114 | } |
115 | |
116 | } // namespace |
117 | |
118 | llvm::Triple::ArchType CGHLSLRuntime::getArch() { |
119 | return CGM.getTarget().getTriple().getArch(); |
120 | } |
121 | |
122 | void CGHLSLRuntime::addConstant(VarDecl *D, Buffer &CB) { |
123 | if (D->getStorageClass() == SC_Static) { |
124 | // For static inside cbuffer, take as global static. |
125 | // Don't add to cbuffer. |
126 | CGM.EmitGlobal(D); |
127 | return; |
128 | } |
129 | |
130 | auto *GV = cast<GlobalVariable>(Val: CGM.GetAddrOfGlobalVar(D)); |
131 | // Add debug info for constVal. |
132 | if (CGDebugInfo *DI = CGM.getModuleDebugInfo()) |
133 | if (CGM.getCodeGenOpts().getDebugInfo() >= |
134 | codegenoptions::DebugInfoKind::LimitedDebugInfo) |
135 | DI->EmitGlobalVariable(GV: cast<GlobalVariable>(Val: GV), Decl: D); |
136 | |
137 | // FIXME: support packoffset. |
138 | // See https://github.com/llvm/llvm-project/issues/57914. |
139 | uint32_t Offset = 0; |
140 | bool HasUserOffset = false; |
141 | |
142 | unsigned LowerBound = HasUserOffset ? Offset : UINT_MAX; |
143 | CB.Constants.emplace_back(args: std::make_pair(x&: GV, y&: LowerBound)); |
144 | } |
145 | |
146 | void CGHLSLRuntime::addBufferDecls(const DeclContext *DC, Buffer &CB) { |
147 | for (Decl *it : DC->decls()) { |
148 | if (auto *ConstDecl = dyn_cast<VarDecl>(Val: it)) { |
149 | addConstant(D: ConstDecl, CB); |
150 | } else if (isa<CXXRecordDecl, EmptyDecl>(Val: it)) { |
151 | // Nothing to do for this declaration. |
152 | } else if (isa<FunctionDecl>(Val: it)) { |
153 | // A function within an cbuffer is effectively a top-level function, |
154 | // as it only refers to globally scoped declarations. |
155 | CGM.EmitTopLevelDecl(D: it); |
156 | } |
157 | } |
158 | } |
159 | |
160 | void CGHLSLRuntime::addBuffer(const HLSLBufferDecl *D) { |
161 | Buffers.emplace_back(Args: Buffer(D)); |
162 | addBufferDecls(D, Buffers.back()); |
163 | } |
164 | |
165 | void CGHLSLRuntime::finishCodeGen() { |
166 | auto &TargetOpts = CGM.getTarget().getTargetOpts(); |
167 | llvm::Module &M = CGM.getModule(); |
168 | Triple T(M.getTargetTriple()); |
169 | if (T.getArch() == Triple::ArchType::dxil) |
170 | addDxilValVersion(ValVersionStr: TargetOpts.DxilValidatorVersion, M); |
171 | |
172 | generateGlobalCtorDtorCalls(); |
173 | if (CGM.getCodeGenOpts().OptimizationLevel == 0) |
174 | addDisableOptimizations(M); |
175 | |
176 | const DataLayout &DL = M.getDataLayout(); |
177 | |
178 | for (auto &Buf : Buffers) { |
179 | layoutBuffer(Buf, DL); |
180 | GlobalVariable *GV = replaceBuffer(Buf); |
181 | M.insertGlobalVariable(GV); |
182 | llvm::hlsl::ResourceClass RC = Buf.IsCBuffer |
183 | ? llvm::hlsl::ResourceClass::CBuffer |
184 | : llvm::hlsl::ResourceClass::SRV; |
185 | llvm::hlsl::ResourceKind RK = Buf.IsCBuffer |
186 | ? llvm::hlsl::ResourceKind::CBuffer |
187 | : llvm::hlsl::ResourceKind::TBuffer; |
188 | addBufferResourceAnnotation(GV, RC, RK, /*IsROV=*/false, |
189 | ET: llvm::hlsl::ElementType::Invalid, Binding&: Buf.Binding); |
190 | } |
191 | } |
192 | |
193 | CGHLSLRuntime::Buffer::Buffer(const HLSLBufferDecl *D) |
194 | : Name(D->getName()), IsCBuffer(D->isCBuffer()), |
195 | Binding(D->getAttr<HLSLResourceBindingAttr>()) {} |
196 | |
197 | void CGHLSLRuntime::addBufferResourceAnnotation(llvm::GlobalVariable *GV, |
198 | llvm::hlsl::ResourceClass RC, |
199 | llvm::hlsl::ResourceKind RK, |
200 | bool IsROV, |
201 | llvm::hlsl::ElementType ET, |
202 | BufferResBinding &Binding) { |
203 | llvm::Module &M = CGM.getModule(); |
204 | |
205 | NamedMDNode *ResourceMD = nullptr; |
206 | switch (RC) { |
207 | case llvm::hlsl::ResourceClass::UAV: |
208 | ResourceMD = M.getOrInsertNamedMetadata(Name: "hlsl.uavs" ); |
209 | break; |
210 | case llvm::hlsl::ResourceClass::SRV: |
211 | ResourceMD = M.getOrInsertNamedMetadata(Name: "hlsl.srvs" ); |
212 | break; |
213 | case llvm::hlsl::ResourceClass::CBuffer: |
214 | ResourceMD = M.getOrInsertNamedMetadata(Name: "hlsl.cbufs" ); |
215 | break; |
216 | default: |
217 | assert(false && "Unsupported buffer type!" ); |
218 | return; |
219 | } |
220 | assert(ResourceMD != nullptr && |
221 | "ResourceMD must have been set by the switch above." ); |
222 | |
223 | llvm::hlsl::FrontendResource Res( |
224 | GV, RK, ET, IsROV, Binding.Reg.value_or(UINT_MAX), Binding.Space); |
225 | ResourceMD->addOperand(M: Res.getMetadata()); |
226 | } |
227 | |
228 | static llvm::hlsl::ElementType |
229 | calculateElementType(const ASTContext &Context, const clang::Type *ResourceTy) { |
230 | using llvm::hlsl::ElementType; |
231 | |
232 | // TODO: We may need to update this when we add things like ByteAddressBuffer |
233 | // that don't have a template parameter (or, indeed, an element type). |
234 | const auto *TST = ResourceTy->getAs<TemplateSpecializationType>(); |
235 | assert(TST && "Resource types must be template specializations" ); |
236 | ArrayRef<TemplateArgument> Args = TST->template_arguments(); |
237 | assert(!Args.empty() && "Resource has no element type" ); |
238 | |
239 | // At this point we have a resource with an element type, so we can assume |
240 | // that it's valid or we would have diagnosed the error earlier. |
241 | QualType ElTy = Args[0].getAsType(); |
242 | |
243 | // We should either have a basic type or a vector of a basic type. |
244 | if (const auto *VecTy = ElTy->getAs<clang::VectorType>()) |
245 | ElTy = VecTy->getElementType(); |
246 | |
247 | if (ElTy->isSignedIntegerType()) { |
248 | switch (Context.getTypeSize(T: ElTy)) { |
249 | case 16: |
250 | return ElementType::I16; |
251 | case 32: |
252 | return ElementType::I32; |
253 | case 64: |
254 | return ElementType::I64; |
255 | } |
256 | } else if (ElTy->isUnsignedIntegerType()) { |
257 | switch (Context.getTypeSize(T: ElTy)) { |
258 | case 16: |
259 | return ElementType::U16; |
260 | case 32: |
261 | return ElementType::U32; |
262 | case 64: |
263 | return ElementType::U64; |
264 | } |
265 | } else if (ElTy->isSpecificBuiltinType(K: BuiltinType::Half)) |
266 | return ElementType::F16; |
267 | else if (ElTy->isSpecificBuiltinType(K: BuiltinType::Float)) |
268 | return ElementType::F32; |
269 | else if (ElTy->isSpecificBuiltinType(K: BuiltinType::Double)) |
270 | return ElementType::F64; |
271 | |
272 | // TODO: We need to handle unorm/snorm float types here once we support them |
273 | llvm_unreachable("Invalid element type for resource" ); |
274 | } |
275 | |
276 | void CGHLSLRuntime::annotateHLSLResource(const VarDecl *D, GlobalVariable *GV) { |
277 | const Type *Ty = D->getType()->getPointeeOrArrayElementType(); |
278 | if (!Ty) |
279 | return; |
280 | const auto *RD = Ty->getAsCXXRecordDecl(); |
281 | if (!RD) |
282 | return; |
283 | const auto *Attr = RD->getAttr<HLSLResourceAttr>(); |
284 | if (!Attr) |
285 | return; |
286 | |
287 | llvm::hlsl::ResourceClass RC = Attr->getResourceClass(); |
288 | llvm::hlsl::ResourceKind RK = Attr->getResourceKind(); |
289 | bool IsROV = Attr->getIsROV(); |
290 | llvm::hlsl::ElementType ET = calculateElementType(Context: CGM.getContext(), ResourceTy: Ty); |
291 | |
292 | BufferResBinding Binding(D->getAttr<HLSLResourceBindingAttr>()); |
293 | addBufferResourceAnnotation(GV, RC, RK, IsROV, ET, Binding); |
294 | } |
295 | |
296 | CGHLSLRuntime::BufferResBinding::BufferResBinding( |
297 | HLSLResourceBindingAttr *Binding) { |
298 | if (Binding) { |
299 | llvm::APInt RegInt(64, 0); |
300 | Binding->getSlot().substr(1).getAsInteger(10, RegInt); |
301 | Reg = RegInt.getLimitedValue(); |
302 | llvm::APInt SpaceInt(64, 0); |
303 | Binding->getSpace().substr(5).getAsInteger(10, SpaceInt); |
304 | Space = SpaceInt.getLimitedValue(); |
305 | } else { |
306 | Space = 0; |
307 | } |
308 | } |
309 | |
310 | void clang::CodeGen::CGHLSLRuntime::setHLSLEntryAttributes( |
311 | const FunctionDecl *FD, llvm::Function *Fn) { |
312 | const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>(); |
313 | assert(ShaderAttr && "All entry functions must have a HLSLShaderAttr" ); |
314 | const StringRef ShaderAttrKindStr = "hlsl.shader" ; |
315 | Fn->addFnAttr(ShaderAttrKindStr, |
316 | ShaderAttr->ConvertShaderTypeToStr(ShaderAttr->getType())); |
317 | if (HLSLNumThreadsAttr *NumThreadsAttr = FD->getAttr<HLSLNumThreadsAttr>()) { |
318 | const StringRef NumThreadsKindStr = "hlsl.numthreads" ; |
319 | std::string NumThreadsStr = |
320 | formatv("{0},{1},{2}" , NumThreadsAttr->getX(), NumThreadsAttr->getY(), |
321 | NumThreadsAttr->getZ()); |
322 | Fn->addFnAttr(Kind: NumThreadsKindStr, Val: NumThreadsStr); |
323 | } |
324 | } |
325 | |
326 | static Value *buildVectorInput(IRBuilder<> &B, Function *F, llvm::Type *Ty) { |
327 | if (const auto *VT = dyn_cast<FixedVectorType>(Val: Ty)) { |
328 | Value *Result = PoisonValue::get(T: Ty); |
329 | for (unsigned I = 0; I < VT->getNumElements(); ++I) { |
330 | Value *Elt = B.CreateCall(Callee: F, Args: {B.getInt32(C: I)}); |
331 | Result = B.CreateInsertElement(Vec: Result, NewElt: Elt, Idx: I); |
332 | } |
333 | return Result; |
334 | } |
335 | return B.CreateCall(Callee: F, Args: {B.getInt32(C: 0)}); |
336 | } |
337 | |
338 | llvm::Value *CGHLSLRuntime::emitInputSemantic(IRBuilder<> &B, |
339 | const ParmVarDecl &D, |
340 | llvm::Type *Ty) { |
341 | assert(D.hasAttrs() && "Entry parameter missing annotation attribute!" ); |
342 | if (D.hasAttr<HLSLSV_GroupIndexAttr>()) { |
343 | llvm::Function *DxGroupIndex = |
344 | CGM.getIntrinsic(Intrinsic::dx_flattened_thread_id_in_group); |
345 | return B.CreateCall(Callee: FunctionCallee(DxGroupIndex)); |
346 | } |
347 | if (D.hasAttr<HLSLSV_DispatchThreadIDAttr>()) { |
348 | llvm::Function *ThreadIDIntrinsic = |
349 | CGM.getIntrinsic(IID: getThreadIdIntrinsic()); |
350 | return buildVectorInput(B, F: ThreadIDIntrinsic, Ty); |
351 | } |
352 | assert(false && "Unhandled parameter attribute" ); |
353 | return nullptr; |
354 | } |
355 | |
356 | void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD, |
357 | llvm::Function *Fn) { |
358 | llvm::Module &M = CGM.getModule(); |
359 | llvm::LLVMContext &Ctx = M.getContext(); |
360 | auto *EntryTy = llvm::FunctionType::get(Result: llvm::Type::getVoidTy(C&: Ctx), isVarArg: false); |
361 | Function *EntryFn = |
362 | Function::Create(EntryTy, Function::ExternalLinkage, FD->getName(), &M); |
363 | |
364 | // Copy function attributes over, we have no argument or return attributes |
365 | // that can be valid on the real entry. |
366 | AttributeList NewAttrs = AttributeList::get(C&: Ctx, Index: AttributeList::FunctionIndex, |
367 | Attrs: Fn->getAttributes().getFnAttrs()); |
368 | EntryFn->setAttributes(NewAttrs); |
369 | setHLSLEntryAttributes(FD, Fn: EntryFn); |
370 | |
371 | // Set the called function as internal linkage. |
372 | Fn->setLinkage(GlobalValue::InternalLinkage); |
373 | |
374 | BasicBlock *BB = BasicBlock::Create(Context&: Ctx, Name: "entry" , Parent: EntryFn); |
375 | IRBuilder<> B(BB); |
376 | llvm::SmallVector<Value *> Args; |
377 | // FIXME: support struct parameters where semantics are on members. |
378 | // See: https://github.com/llvm/llvm-project/issues/57874 |
379 | unsigned SRetOffset = 0; |
380 | for (const auto &Param : Fn->args()) { |
381 | if (Param.hasStructRetAttr()) { |
382 | // FIXME: support output. |
383 | // See: https://github.com/llvm/llvm-project/issues/57874 |
384 | SRetOffset = 1; |
385 | Args.emplace_back(Args: PoisonValue::get(T: Param.getType())); |
386 | continue; |
387 | } |
388 | const ParmVarDecl *PD = FD->getParamDecl(i: Param.getArgNo() - SRetOffset); |
389 | Args.push_back(Elt: emitInputSemantic(B, D: *PD, Ty: Param.getType())); |
390 | } |
391 | |
392 | CallInst *CI = B.CreateCall(Callee: FunctionCallee(Fn), Args); |
393 | (void)CI; |
394 | // FIXME: Handle codegen for return type semantics. |
395 | // See: https://github.com/llvm/llvm-project/issues/57875 |
396 | B.CreateRetVoid(); |
397 | } |
398 | |
399 | static void gatherFunctions(SmallVectorImpl<Function *> &Fns, llvm::Module &M, |
400 | bool CtorOrDtor) { |
401 | const auto *GV = |
402 | M.getNamedGlobal(Name: CtorOrDtor ? "llvm.global_ctors" : "llvm.global_dtors" ); |
403 | if (!GV) |
404 | return; |
405 | const auto *CA = dyn_cast<ConstantArray>(Val: GV->getInitializer()); |
406 | if (!CA) |
407 | return; |
408 | // The global_ctor array elements are a struct [Priority, Fn *, COMDat]. |
409 | // HLSL neither supports priorities or COMDat values, so we will check those |
410 | // in an assert but not handle them. |
411 | |
412 | llvm::SmallVector<Function *> CtorFns; |
413 | for (const auto &Ctor : CA->operands()) { |
414 | if (isa<ConstantAggregateZero>(Val: Ctor)) |
415 | continue; |
416 | ConstantStruct *CS = cast<ConstantStruct>(Val: Ctor); |
417 | |
418 | assert(cast<ConstantInt>(CS->getOperand(0))->getValue() == 65535 && |
419 | "HLSL doesn't support setting priority for global ctors." ); |
420 | assert(isa<ConstantPointerNull>(CS->getOperand(2)) && |
421 | "HLSL doesn't support COMDat for global ctors." ); |
422 | Fns.push_back(Elt: cast<Function>(Val: CS->getOperand(i_nocapture: 1))); |
423 | } |
424 | } |
425 | |
426 | void CGHLSLRuntime::generateGlobalCtorDtorCalls() { |
427 | llvm::Module &M = CGM.getModule(); |
428 | SmallVector<Function *> CtorFns; |
429 | SmallVector<Function *> DtorFns; |
430 | gatherFunctions(Fns&: CtorFns, M, CtorOrDtor: true); |
431 | gatherFunctions(Fns&: DtorFns, M, CtorOrDtor: false); |
432 | |
433 | // Insert a call to the global constructor at the beginning of the entry block |
434 | // to externally exported functions. This is a bit of a hack, but HLSL allows |
435 | // global constructors, but doesn't support driver initialization of globals. |
436 | for (auto &F : M.functions()) { |
437 | if (!F.hasFnAttribute(Kind: "hlsl.shader" )) |
438 | continue; |
439 | IRBuilder<> B(&F.getEntryBlock(), F.getEntryBlock().begin()); |
440 | for (auto *Fn : CtorFns) |
441 | B.CreateCall(Callee: FunctionCallee(Fn)); |
442 | |
443 | // Insert global dtors before the terminator of the last instruction |
444 | B.SetInsertPoint(F.back().getTerminator()); |
445 | for (auto *Fn : DtorFns) |
446 | B.CreateCall(Callee: FunctionCallee(Fn)); |
447 | } |
448 | |
449 | // No need to keep global ctors/dtors for non-lib profile after call to |
450 | // ctors/dtors added for entry. |
451 | Triple T(M.getTargetTriple()); |
452 | if (T.getEnvironment() != Triple::EnvironmentType::Library) { |
453 | if (auto *GV = M.getNamedGlobal(Name: "llvm.global_ctors" )) |
454 | GV->eraseFromParent(); |
455 | if (auto *GV = M.getNamedGlobal(Name: "llvm.global_dtors" )) |
456 | GV->eraseFromParent(); |
457 | } |
458 | } |
459 | |