//===- DXILOpBuilder.cpp - Helper class for building DXILOp functions -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
/// \file This file contains a helper class for building DXIL op functions.
10//===----------------------------------------------------------------------===//
11
12#include "DXILOpBuilder.h"
13#include "DXILConstants.h"
14#include "llvm/IR/IRBuilder.h"
15#include "llvm/IR/Module.h"
16#include "llvm/Support/DXILABI.h"
17#include "llvm/Support/ErrorHandling.h"
18
19using namespace llvm;
20using namespace llvm::dxil;
21
// Prefix shared by every DXIL operation function name. constructOverloadName
// appends the op class name and, for non-VOID overloads, a "." plus a type
// suffix (e.g. "dx.op.<opclass>.<type>").
constexpr StringLiteral DXILOpNamePrefix = "dx.op.";
23
namespace {

/// Single-bit flags naming the overload types a DXIL operation can be
/// instantiated with. OpCodeProperty::OverloadTys holds a set of these bits
/// and is tested with a bitwise AND against one flag.
///
/// Kept as an unscoped enum on purpose: code in this file orders enumerators
/// with '<' (everything below UserDefineType is a scalar kind) and mixes them
/// with uint16_t masks.
enum OverloadKind : uint16_t {
  VOID = 0x001,
  HALF = 0x002,
  FLOAT = 0x004,
  DOUBLE = 0x008,
  I1 = 0x010,
  I8 = 0x020,
  I16 = 0x040,
  I32 = 0x080,
  I64 = 0x100,
  UserDefineType = 0x200,
  ObjectType = 0x400,
};

} // namespace
41
42static const char *getOverloadTypeName(OverloadKind Kind) {
43 switch (Kind) {
44 case OverloadKind::HALF:
45 return "f16";
46 case OverloadKind::FLOAT:
47 return "f32";
48 case OverloadKind::DOUBLE:
49 return "f64";
50 case OverloadKind::I1:
51 return "i1";
52 case OverloadKind::I8:
53 return "i8";
54 case OverloadKind::I16:
55 return "i16";
56 case OverloadKind::I32:
57 return "i32";
58 case OverloadKind::I64:
59 return "i64";
60 case OverloadKind::VOID:
61 case OverloadKind::ObjectType:
62 case OverloadKind::UserDefineType:
63 break;
64 }
65 llvm_unreachable("invalid overload type for name");
66 return "void";
67}
68
69static OverloadKind getOverloadKind(Type *Ty) {
70 Type::TypeID T = Ty->getTypeID();
71 switch (T) {
72 case Type::VoidTyID:
73 return OverloadKind::VOID;
74 case Type::HalfTyID:
75 return OverloadKind::HALF;
76 case Type::FloatTyID:
77 return OverloadKind::FLOAT;
78 case Type::DoubleTyID:
79 return OverloadKind::DOUBLE;
80 case Type::IntegerTyID: {
81 IntegerType *ITy = cast<IntegerType>(Val: Ty);
82 unsigned Bits = ITy->getBitWidth();
83 switch (Bits) {
84 case 1:
85 return OverloadKind::I1;
86 case 8:
87 return OverloadKind::I8;
88 case 16:
89 return OverloadKind::I16;
90 case 32:
91 return OverloadKind::I32;
92 case 64:
93 return OverloadKind::I64;
94 default:
95 llvm_unreachable("invalid overload type");
96 return OverloadKind::VOID;
97 }
98 }
99 case Type::PointerTyID:
100 return OverloadKind::UserDefineType;
101 case Type::StructTyID:
102 return OverloadKind::ObjectType;
103 default:
104 llvm_unreachable("invalid overload type");
105 return OverloadKind::VOID;
106 }
107}
108
109static std::string getTypeName(OverloadKind Kind, Type *Ty) {
110 if (Kind < OverloadKind::UserDefineType) {
111 return getOverloadTypeName(Kind);
112 } else if (Kind == OverloadKind::UserDefineType) {
113 StructType *ST = cast<StructType>(Val: Ty);
114 return ST->getStructName().str();
115 } else if (Kind == OverloadKind::ObjectType) {
116 StructType *ST = cast<StructType>(Val: Ty);
117 return ST->getStructName().str();
118 } else {
119 std::string Str;
120 raw_string_ostream OS(Str);
121 Ty->print(O&: OS);
122 return OS.str();
123 }
124}
125
// Static properties of a single DXIL operation. getOpCodeProperty, generated
// by TableGen into DXILOperation.inc, returns pointers to instances of this
// struct.
// NOTE(review): the generated table presumably initializes these fields
// positionally, so field order must stay in sync with the TableGen emitter.
// Confirm before reordering or renaming.
struct OpCodeProperty {
  // The DXIL opcode these properties describe.
  dxil::OpCode OpCode;
  // Offset in DXILOpCodeNameTable.
  unsigned OpCodeNameOffset;
  // Opcode class; its name forms the "<opclass>" part of "dx.op.<opclass>".
  dxil::OpCodeClass OpCodeClass;
  // Offset in DXILOpCodeClassNameTable.
  unsigned OpCodeClassNameOffset;
  // Set of OverloadKind bits this operation accepts. createDXILOpCall rejects
  // an overload type whose bit is not set. When OverloadParamIndex < 0,
  // getOverloadTy expects exactly one bit to be set.
  uint16_t OverloadTys;
  // Function attribute associated with the operation (not read by the code in
  // this file).
  llvm::Attribute::AttrKind FuncAttr;
  // Index of the type that selects the overload: 0 is the return type and
  // N > 0 is parameter N - 1 of the function type handed to getOverloadTy.
  int OverloadParamIndex; // parameter index which control the overload.
                          // When < 0, should be only 1 overload type.
  // NOTE(review): getDXILOpFunctionType adds the return type separately and
  // then turns each of these NumOfParameters kinds into an argument, which
  // conflicts with "include return value" below. Confirm against the
  // TableGen emitter which convention the generated table uses.
  unsigned NumOfParameters; // Number of parameters include return value.
  // Presumably the start of this op's entries in the table read by
  // getOpCodeParameterKind. Verify against the generated code.
  unsigned ParameterTableOffset; // Offset in ParameterTable.
};
141
// Include getOpCodeClassName, getOpCodeProperty, getOpCodeName and
// getOpCodeParameterKind, which are generated by TableGen.
144#define DXIL_OP_OPERATION_TABLE
145#include "DXILOperation.inc"
146#undef DXIL_OP_OPERATION_TABLE
147
148static std::string constructOverloadName(OverloadKind Kind, Type *Ty,
149 const OpCodeProperty &Prop) {
150 if (Kind == OverloadKind::VOID) {
151 return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str();
152 }
153 return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop) + "." +
154 getTypeName(Kind, Ty))
155 .str();
156}
157
158static std::string constructOverloadTypeName(OverloadKind Kind,
159 StringRef TypeName) {
160 if (Kind == OverloadKind::VOID)
161 return TypeName.str();
162
163 assert(Kind < OverloadKind::UserDefineType && "invalid overload kind");
164 return (Twine(TypeName) + getOverloadTypeName(Kind)).str();
165}
166
167static StructType *getOrCreateStructType(StringRef Name,
168 ArrayRef<Type *> EltTys,
169 LLVMContext &Ctx) {
170 StructType *ST = StructType::getTypeByName(C&: Ctx, Name);
171 if (ST)
172 return ST;
173
174 return StructType::create(Context&: Ctx, Elements: EltTys, Name);
175}
176
177static StructType *getResRetType(Type *OverloadTy, LLVMContext &Ctx) {
178 OverloadKind Kind = getOverloadKind(Ty: OverloadTy);
179 std::string TypeName = constructOverloadTypeName(Kind, TypeName: "dx.types.ResRet.");
180 Type *FieldTypes[5] = {OverloadTy, OverloadTy, OverloadTy, OverloadTy,
181 Type::getInt32Ty(C&: Ctx)};
182 return getOrCreateStructType(Name: TypeName, EltTys: FieldTypes, Ctx);
183}
184
185static StructType *getHandleType(LLVMContext &Ctx) {
186 return getOrCreateStructType(Name: "dx.types.Handle", EltTys: PointerType::getUnqual(C&: Ctx),
187 Ctx);
188}
189
190static Type *getTypeFromParameterKind(ParameterKind Kind, Type *OverloadTy) {
191 auto &Ctx = OverloadTy->getContext();
192 switch (Kind) {
193 case ParameterKind::Void:
194 return Type::getVoidTy(C&: Ctx);
195 case ParameterKind::Half:
196 return Type::getHalfTy(C&: Ctx);
197 case ParameterKind::Float:
198 return Type::getFloatTy(C&: Ctx);
199 case ParameterKind::Double:
200 return Type::getDoubleTy(C&: Ctx);
201 case ParameterKind::I1:
202 return Type::getInt1Ty(C&: Ctx);
203 case ParameterKind::I8:
204 return Type::getInt8Ty(C&: Ctx);
205 case ParameterKind::I16:
206 return Type::getInt16Ty(C&: Ctx);
207 case ParameterKind::I32:
208 return Type::getInt32Ty(C&: Ctx);
209 case ParameterKind::I64:
210 return Type::getInt64Ty(C&: Ctx);
211 case ParameterKind::Overload:
212 return OverloadTy;
213 case ParameterKind::ResourceRet:
214 return getResRetType(OverloadTy, Ctx);
215 case ParameterKind::DXILHandle:
216 return getHandleType(Ctx);
217 default:
218 break;
219 }
220 llvm_unreachable("Invalid parameter kind");
221 return nullptr;
222}
223
224/// Construct DXIL function type. This is the type of a function with
225/// the following prototype
226/// OverloadType dx.op.<opclass>.<return-type>(int opcode, <param types>)
227/// <param-types> are constructed from types in Prop.
228/// \param Prop Structure containing DXIL Operation properties based on
229/// its specification in DXIL.td.
230/// \param OverloadTy Return type to be used to construct DXIL function type.
231static FunctionType *getDXILOpFunctionType(const OpCodeProperty *Prop,
232 Type *ReturnTy, Type *OverloadTy) {
233 SmallVector<Type *> ArgTys;
234
235 auto ParamKinds = getOpCodeParameterKind(*Prop);
236
237 // Add ReturnTy as return type of the function
238 ArgTys.emplace_back(Args&: ReturnTy);
239
240 // Add DXIL Opcode value type viz., Int32 as first argument
241 ArgTys.emplace_back(Args: Type::getInt32Ty(C&: OverloadTy->getContext()));
242
243 // Add DXIL Operation parameter types as specified in DXIL properties
244 for (unsigned I = 0; I < Prop->NumOfParameters; ++I) {
245 ParameterKind Kind = ParamKinds[I];
246 ArgTys.emplace_back(Args: getTypeFromParameterKind(Kind, OverloadTy));
247 }
248 return FunctionType::get(
249 Result: ArgTys[0], Params: ArrayRef<Type *>(&ArgTys[1], ArgTys.size() - 1), isVarArg: false);
250}
251
252namespace llvm {
253namespace dxil {
254
255CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy,
256 Type *OverloadTy,
257 SmallVector<Value *> Args) {
258 const OpCodeProperty *Prop = getOpCodeProperty(OpCode);
259
260 OverloadKind Kind = getOverloadKind(Ty: OverloadTy);
261 if ((Prop->OverloadTys & (uint16_t)Kind) == 0) {
262 report_fatal_error(reason: "Invalid Overload Type", /* gen_crash_diag=*/false);
263 }
264
265 std::string DXILFnName = constructOverloadName(Kind, Ty: OverloadTy, Prop: *Prop);
266 FunctionCallee DXILFn;
267 // Get the function with name DXILFnName, if one exists
268 if (auto *Func = M.getFunction(DXILFnName)) {
269 DXILFn = FunctionCallee(Func);
270 } else {
271 // Construct and add a function with name DXILFnName
272 FunctionType *DXILOpFT = getDXILOpFunctionType(Prop, ReturnTy, OverloadTy);
273 DXILFn = M.getOrInsertFunction(Name: DXILFnName, T: DXILOpFT);
274 }
275
276 return B.CreateCall(Callee: DXILFn, Args);
277}
278
279Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, FunctionType *FT) {
280
281 const OpCodeProperty *Prop = getOpCodeProperty(OpCode);
282 // If DXIL Op has no overload parameter, just return the
283 // precise return type specified.
284 if (Prop->OverloadParamIndex < 0) {
285 auto &Ctx = FT->getContext();
286 switch (Prop->OverloadTys) {
287 case OverloadKind::VOID:
288 return Type::getVoidTy(C&: Ctx);
289 case OverloadKind::HALF:
290 return Type::getHalfTy(C&: Ctx);
291 case OverloadKind::FLOAT:
292 return Type::getFloatTy(C&: Ctx);
293 case OverloadKind::DOUBLE:
294 return Type::getDoubleTy(C&: Ctx);
295 case OverloadKind::I1:
296 return Type::getInt1Ty(C&: Ctx);
297 case OverloadKind::I8:
298 return Type::getInt8Ty(C&: Ctx);
299 case OverloadKind::I16:
300 return Type::getInt16Ty(C&: Ctx);
301 case OverloadKind::I32:
302 return Type::getInt32Ty(C&: Ctx);
303 case OverloadKind::I64:
304 return Type::getInt64Ty(C&: Ctx);
305 default:
306 llvm_unreachable("invalid overload type");
307 return nullptr;
308 }
309 }
310
311 // Prop->OverloadParamIndex is 0, overload type is FT->getReturnType().
312 Type *OverloadType = FT->getReturnType();
313 if (Prop->OverloadParamIndex != 0) {
314 // Skip Return Type.
315 OverloadType = FT->getParamType(i: Prop->OverloadParamIndex - 1);
316 }
317
318 auto ParamKinds = getOpCodeParameterKind(*Prop);
319 auto Kind = ParamKinds[Prop->OverloadParamIndex];
320 // For ResRet and CBufferRet, OverloadTy is in field of StructType.
321 if (Kind == ParameterKind::CBufferRet ||
322 Kind == ParameterKind::ResourceRet) {
323 auto *ST = cast<StructType>(Val: OverloadType);
324 OverloadType = ST->getElementType(N: 0);
325 }
326 return OverloadType;
327}
328
329const char *DXILOpBuilder::getOpCodeName(dxil::OpCode DXILOp) {
330 return ::getOpCodeName(DXILOp);
331}
332} // namespace dxil
333} // namespace llvm
334

source code of llvm/lib/Target/DirectX/DXILOpBuilder.cpp