1//===- ConvertLaunchFuncToVulkanCalls.cpp - MLIR Vulkan conversion passes -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements a pass to convert a Vulkan launch call into a sequence
// of Vulkan runtime calls. The Vulkan runtime API surface is huge, so instead
// of exposing a separate external function in IR for each of its entry points,
// we expose a few external functions provided by wrapper libraries that manage
// the Vulkan runtime.
14//
15//===----------------------------------------------------------------------===//
16
17#include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
18
19#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20#include "mlir/IR/Attributes.h"
21#include "mlir/IR/Builders.h"
22#include "mlir/IR/BuiltinOps.h"
23#include "mlir/Pass/Pass.h"
24#include "llvm/ADT/SmallString.h"
25#include "llvm/Support/FormatVariadic.h"
26
27namespace mlir {
28#define GEN_PASS_DEF_CONVERTVULKANLAUNCHFUNCTOVULKANCALLSPASS
29#include "mlir/Conversion/Passes.h.inc"
30} // namespace mlir
31
32using namespace mlir;
33
// Attribute names read from the `vulkanLaunch` call.
static constexpr const char *kSPIRVBlobAttrName = "spirv_blob";
static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point";
static constexpr const char *kSPIRVElementTypesAttrName = "spirv_element_types";

// Names of the launch calls recognized by this pass.
static constexpr const char *kVulkanLaunch = "vulkanLaunch";
static constexpr const char *kCInterfaceVulkanLaunch =
    "_mlir_ciface_vulkanLaunch";

// Vulkan runtime wrapper functions, listed in the order the pass emits calls
// to them.
static constexpr const char *kInitVulkan = "initVulkan";
static constexpr const char *kSetBinaryShader = "setBinaryShader";
static constexpr const char *kSetEntryPoint = "setEntryPoint";
static constexpr const char *kSetNumWorkGroups = "setNumWorkGroups";
static constexpr const char *kRunOnVulkan = "runOnVulkan";
static constexpr const char *kDeinitVulkan = "deinitVulkan";

// Name of the LLVM global holding the serialized SPIR-V binary.
static constexpr const char *kSPIRVBinary = "SPIRV_BIN";
47
48namespace {
49
50/// A pass to convert vulkan launch call op into a sequence of Vulkan
51/// runtime calls in the following order:
52///
53/// * initVulkan -- initializes vulkan runtime
54/// * bindMemRef -- binds memref
55/// * setBinaryShader -- sets the binary shader data
56/// * setEntryPoint -- sets the entry point name
57/// * setNumWorkGroups -- sets the number of a local workgroups
58/// * runOnVulkan -- runs vulkan runtime
59/// * deinitVulkan -- deinitializes vulkan runtime
60///
61class VulkanLaunchFuncToVulkanCallsPass
62 : public impl::ConvertVulkanLaunchFuncToVulkanCallsPassBase<
63 VulkanLaunchFuncToVulkanCallsPass> {
64private:
65 void initializeCachedTypes() {
66 llvmFloatType = Float32Type::get(&getContext());
67 llvmVoidType = LLVM::LLVMVoidType::get(&getContext());
68 llvmPointerType = LLVM::LLVMPointerType::get(&getContext());
69 llvmInt32Type = IntegerType::get(&getContext(), 32);
70 llvmInt64Type = IntegerType::get(&getContext(), 64);
71 }
72
73 Type getMemRefType(uint32_t rank, Type elemenType) {
74 // According to the MLIR doc memref argument is converted into a
75 // pointer-to-struct argument of type:
76 // template <typename Elem, size_t Rank>
77 // struct {
78 // Elem *allocated;
79 // Elem *aligned;
80 // int64_t offset;
81 // int64_t sizes[Rank]; // omitted when rank == 0
82 // int64_t strides[Rank]; // omitted when rank == 0
83 // };
84 auto llvmArrayRankElementSizeType =
85 LLVM::LLVMArrayType::get(getInt64Type(), rank);
86
87 // Create a type
88 // `!llvm<"{ `element-type`*, `element-type`*, i64,
89 // [`rank` x i64], [`rank` x i64]}">`.
90 return LLVM::LLVMStructType::getLiteral(
91 context: &getContext(),
92 types: {llvmPointerType, llvmPointerType, getInt64Type(),
93 llvmArrayRankElementSizeType, llvmArrayRankElementSizeType});
94 }
95
96 Type getVoidType() { return llvmVoidType; }
97 Type getPointerType() { return llvmPointerType; }
98 Type getInt32Type() { return llvmInt32Type; }
99 Type getInt64Type() { return llvmInt64Type; }
100
101 /// Creates an LLVM global for the given `name`.
102 Value createEntryPointNameConstant(StringRef name, Location loc,
103 OpBuilder &builder);
104
105 /// Declares all needed runtime functions.
106 void declareVulkanFunctions(Location loc);
107
108 /// Checks whether the given LLVM::CallOp is a vulkan launch call op.
109 bool isVulkanLaunchCallOp(LLVM::CallOp callOp) {
110 return (callOp.getCallee() && *callOp.getCallee() == kVulkanLaunch &&
111 callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands);
112 }
113
114 /// Checks whether the given LLVM::CallOp is a "ci_face" vulkan launch call
115 /// op.
116 bool isCInterfaceVulkanLaunchCallOp(LLVM::CallOp callOp) {
117 return (callOp.getCallee() &&
118 *callOp.getCallee() == kCInterfaceVulkanLaunch &&
119 callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands);
120 }
121
122 /// Translates the given `vulkanLaunchCallOp` to the sequence of Vulkan
123 /// runtime calls.
124 void translateVulkanLaunchCall(LLVM::CallOp vulkanLaunchCallOp);
125
126 /// Creates call to `bindMemRef` for each memref operand.
127 void createBindMemRefCalls(LLVM::CallOp vulkanLaunchCallOp,
128 Value vulkanRuntime);
129
130 /// Collects SPIRV attributes from the given `vulkanLaunchCallOp`.
131 void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp);
132
133 /// Deduces a rank from the given 'launchCallArg`.
134 LogicalResult deduceMemRefRank(Value launchCallArg, uint32_t &rank);
135
136 /// Returns a string representation from the given `type`.
137 StringRef stringifyType(Type type) {
138 if (isa<Float32Type>(type))
139 return "Float";
140 if (isa<Float16Type>(type))
141 return "Half";
142 if (auto intType = dyn_cast<IntegerType>(type)) {
143 if (intType.getWidth() == 32)
144 return "Int32";
145 if (intType.getWidth() == 16)
146 return "Int16";
147 if (intType.getWidth() == 8)
148 return "Int8";
149 }
150
151 llvm_unreachable("unsupported type");
152 }
153
154public:
155 using Base::Base;
156
157 void runOnOperation() override;
158
159private:
160 Type llvmFloatType;
161 Type llvmVoidType;
162 Type llvmPointerType;
163 Type llvmInt32Type;
164 Type llvmInt64Type;
165
166 struct SPIRVAttributes {
167 StringAttr blob;
168 StringAttr entryPoint;
169 SmallVector<Type> elementTypes;
170 };
171
172 // TODO: Use an associative array to support multiple vulkan launch calls.
173 SPIRVAttributes spirvAttributes;
174 /// The number of vulkan launch configuration operands, placed at the leading
175 /// positions of the operand list.
176 static constexpr unsigned kVulkanLaunchNumConfigOperands = 3;
177};
178
179} // namespace
180
181void VulkanLaunchFuncToVulkanCallsPass::runOnOperation() {
182 initializeCachedTypes();
183
184 // Collect SPIR-V attributes such as `spirv_blob` and
185 // `spirv_entry_point_name`.
186 getOperation().walk([this](LLVM::CallOp op) {
187 if (isVulkanLaunchCallOp(op))
188 collectSPIRVAttributes(op);
189 });
190
191 // Convert vulkan launch call op into a sequence of Vulkan runtime calls.
192 getOperation().walk([this](LLVM::CallOp op) {
193 if (isCInterfaceVulkanLaunchCallOp(op))
194 translateVulkanLaunchCall(op);
195 });
196}
197
198void VulkanLaunchFuncToVulkanCallsPass::collectSPIRVAttributes(
199 LLVM::CallOp vulkanLaunchCallOp) {
200 // Check that `kSPIRVBinary` and `kSPIRVEntryPoint` are present in attributes
201 // for the given vulkan launch call.
202 auto spirvBlobAttr =
203 vulkanLaunchCallOp->getAttrOfType<StringAttr>(kSPIRVBlobAttrName);
204 if (!spirvBlobAttr) {
205 vulkanLaunchCallOp.emitError()
206 << "missing " << kSPIRVBlobAttrName << " attribute";
207 return signalPassFailure();
208 }
209
210 auto spirvEntryPointNameAttr =
211 vulkanLaunchCallOp->getAttrOfType<StringAttr>(kSPIRVEntryPointAttrName);
212 if (!spirvEntryPointNameAttr) {
213 vulkanLaunchCallOp.emitError()
214 << "missing " << kSPIRVEntryPointAttrName << " attribute";
215 return signalPassFailure();
216 }
217
218 auto spirvElementTypesAttr =
219 vulkanLaunchCallOp->getAttrOfType<ArrayAttr>(kSPIRVElementTypesAttrName);
220 if (!spirvElementTypesAttr) {
221 vulkanLaunchCallOp.emitError()
222 << "missing " << kSPIRVElementTypesAttrName << " attribute";
223 return signalPassFailure();
224 }
225 if (llvm::any_of(spirvElementTypesAttr,
226 [](Attribute attr) { return !isa<TypeAttr>(Val: attr); })) {
227 vulkanLaunchCallOp.emitError()
228 << "expected " << spirvElementTypesAttr << " to be an array of types";
229 return signalPassFailure();
230 }
231
232 spirvAttributes.blob = spirvBlobAttr;
233 spirvAttributes.entryPoint = spirvEntryPointNameAttr;
234 spirvAttributes.elementTypes =
235 llvm::to_vector(spirvElementTypesAttr.getAsValueRange<mlir::TypeAttr>());
236}
237
238void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
239 LLVM::CallOp cInterfaceVulkanLaunchCallOp, Value vulkanRuntime) {
240 if (cInterfaceVulkanLaunchCallOp.getNumOperands() ==
241 kVulkanLaunchNumConfigOperands)
242 return;
243 OpBuilder builder(cInterfaceVulkanLaunchCallOp);
244 Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
245
246 // Create LLVM constant for the descriptor set index.
247 // Bind all memrefs to the `0` descriptor set, the same way as `GPUToSPIRV`
248 // pass does.
249 Value descriptorSet =
250 builder.create<LLVM::ConstantOp>(loc, getInt32Type(), 0);
251
252 for (auto [index, ptrToMemRefDescriptor] :
253 llvm::enumerate(cInterfaceVulkanLaunchCallOp.getOperands().drop_front(
254 kVulkanLaunchNumConfigOperands))) {
255 // Create LLVM constant for the descriptor binding index.
256 Value descriptorBinding =
257 builder.create<LLVM::ConstantOp>(loc, getInt32Type(), index);
258
259 if (index >= spirvAttributes.elementTypes.size()) {
260 cInterfaceVulkanLaunchCallOp.emitError()
261 << kSPIRVElementTypesAttrName << " missing element type for "
262 << ptrToMemRefDescriptor;
263 return signalPassFailure();
264 }
265
266 uint32_t rank = 0;
267 Type type = spirvAttributes.elementTypes[index];
268 if (failed(deduceMemRefRank(ptrToMemRefDescriptor, rank))) {
269 cInterfaceVulkanLaunchCallOp.emitError()
270 << "invalid memref descriptor " << ptrToMemRefDescriptor.getType();
271 return signalPassFailure();
272 }
273
274 auto symbolName =
275 llvm::formatv("bindMemRef{0}D{1}", rank, stringifyType(type)).str();
276 // Create call to `bindMemRef`.
277 builder.create<LLVM::CallOp>(
278 loc, TypeRange(), StringRef(symbolName.data(), symbolName.size()),
279 ValueRange{vulkanRuntime, descriptorSet, descriptorBinding,
280 ptrToMemRefDescriptor});
281 }
282}
283
284LogicalResult
285VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRank(Value launchCallArg,
286 uint32_t &rank) {
287 // Deduce the rank from the type used to allocate the lowered MemRef.
288 auto alloca = launchCallArg.getDefiningOp<LLVM::AllocaOp>();
289 if (!alloca)
290 return failure();
291
292 std::optional<Type> elementType = alloca.getElemType();
293 assert(elementType && "expected to work with opaque pointers");
294 auto llvmDescriptorTy = dyn_cast<LLVM::LLVMStructType>(Val&: *elementType);
295 // template <typename Elem, size_t Rank>
296 // struct {
297 // Elem *allocated;
298 // Elem *aligned;
299 // int64_t offset;
300 // int64_t sizes[Rank]; // omitted when rank == 0
301 // int64_t strides[Rank]; // omitted when rank == 0
302 // };
303 if (!llvmDescriptorTy)
304 return failure();
305
306 if (llvmDescriptorTy.getBody().size() == 3) {
307 rank = 0;
308 return success();
309 }
310 rank =
311 cast<LLVM::LLVMArrayType>(llvmDescriptorTy.getBody()[3]).getNumElements();
312 return success();
313}
314
315void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
316 ModuleOp module = getOperation();
317 auto builder = OpBuilder::atBlockEnd(block: module.getBody());
318
319 if (!module.lookupSymbol(kSetEntryPoint)) {
320 builder.create<LLVM::LLVMFuncOp>(
321 loc, kSetEntryPoint,
322 LLVM::LLVMFunctionType::get(getVoidType(),
323 {getPointerType(), getPointerType()}));
324 }
325
326 if (!module.lookupSymbol(kSetNumWorkGroups)) {
327 builder.create<LLVM::LLVMFuncOp>(
328 loc, kSetNumWorkGroups,
329 LLVM::LLVMFunctionType::get(getVoidType(),
330 {getPointerType(), getInt64Type(),
331 getInt64Type(), getInt64Type()}));
332 }
333
334 if (!module.lookupSymbol(kSetBinaryShader)) {
335 builder.create<LLVM::LLVMFuncOp>(
336 loc, kSetBinaryShader,
337 LLVM::LLVMFunctionType::get(
338 getVoidType(),
339 {getPointerType(), getPointerType(), getInt32Type()}));
340 }
341
342 if (!module.lookupSymbol(kRunOnVulkan)) {
343 builder.create<LLVM::LLVMFuncOp>(
344 loc, kRunOnVulkan,
345 LLVM::LLVMFunctionType::get(getVoidType(), {getPointerType()}));
346 }
347
348 for (unsigned i = 1; i <= 3; i++) {
349 SmallVector<Type, 5> types{
350 Float32Type::get(&getContext()), IntegerType::get(&getContext(), 32),
351 IntegerType::get(&getContext(), 16), IntegerType::get(&getContext(), 8),
352 Float16Type::get(&getContext())};
353 for (auto type : types) {
354 std::string fnName = "bindMemRef" + std::to_string(val: i) + "D" +
355 std::string(stringifyType(type));
356 if (isa<Float16Type>(type))
357 type = IntegerType::get(&getContext(), 16);
358 if (!module.lookupSymbol(fnName)) {
359 auto fnType = LLVM::LLVMFunctionType::get(
360 getVoidType(),
361 {llvmPointerType, getInt32Type(), getInt32Type(), llvmPointerType},
362 /*isVarArg=*/false);
363 builder.create<LLVM::LLVMFuncOp>(loc, fnName, fnType);
364 }
365 }
366 }
367
368 if (!module.lookupSymbol(kInitVulkan)) {
369 builder.create<LLVM::LLVMFuncOp>(
370 loc, kInitVulkan, LLVM::LLVMFunctionType::get(getPointerType(), {}));
371 }
372
373 if (!module.lookupSymbol(kDeinitVulkan)) {
374 builder.create<LLVM::LLVMFuncOp>(
375 loc, kDeinitVulkan,
376 LLVM::LLVMFunctionType::get(getVoidType(), {getPointerType()}));
377 }
378}
379
380Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant(
381 StringRef name, Location loc, OpBuilder &builder) {
382 SmallString<16> shaderName(name.begin(), name.end());
383 // Append `\0` to follow C style string given that LLVM::createGlobalString()
384 // won't handle this directly for us.
385 shaderName.push_back(Elt: '\0');
386
387 std::string entryPointGlobalName = (name + "_spv_entry_point_name").str();
388 return LLVM::createGlobalString(loc, builder, entryPointGlobalName,
389 shaderName, LLVM::Linkage::Internal);
390}
391
392void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
393 LLVM::CallOp cInterfaceVulkanLaunchCallOp) {
394 OpBuilder builder(cInterfaceVulkanLaunchCallOp);
395 Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
396 // Create call to `initVulkan`.
397 auto initVulkanCall = builder.create<LLVM::CallOp>(
398 loc, TypeRange{getPointerType()}, kInitVulkan);
399 // The result of `initVulkan` function is a pointer to Vulkan runtime, we
400 // need to pass that pointer to each Vulkan runtime call.
401 auto vulkanRuntime = initVulkanCall.getResult();
402
403 // Create LLVM global with SPIR-V binary data, so we can pass a pointer with
404 // that data to runtime call.
405 Value ptrToSPIRVBinary = LLVM::createGlobalString(
406 loc, builder, kSPIRVBinary, spirvAttributes.blob.getValue(),
407 LLVM::Linkage::Internal);
408
409 // Create LLVM constant for the size of SPIR-V binary shader.
410 Value binarySize = builder.create<LLVM::ConstantOp>(
411 loc, getInt32Type(), spirvAttributes.blob.getValue().size());
412
413 // Create call to `bindMemRef` for each memref operand.
414 createBindMemRefCalls(cInterfaceVulkanLaunchCallOp, vulkanRuntime);
415
416 // Create call to `setBinaryShader` runtime function with the given pointer to
417 // SPIR-V binary and binary size.
418 builder.create<LLVM::CallOp>(
419 loc, TypeRange(), kSetBinaryShader,
420 ValueRange{vulkanRuntime, ptrToSPIRVBinary, binarySize});
421 // Create LLVM global with entry point name.
422 Value entryPointName = createEntryPointNameConstant(
423 spirvAttributes.entryPoint.getValue(), loc, builder);
424 // Create call to `setEntryPoint` runtime function with the given pointer to
425 // entry point name.
426 builder.create<LLVM::CallOp>(loc, TypeRange(), kSetEntryPoint,
427 ValueRange{vulkanRuntime, entryPointName});
428
429 // Create number of local workgroup for each dimension.
430 builder.create<LLVM::CallOp>(
431 loc, TypeRange(), kSetNumWorkGroups,
432 ValueRange{vulkanRuntime, cInterfaceVulkanLaunchCallOp.getOperand(0),
433 cInterfaceVulkanLaunchCallOp.getOperand(1),
434 cInterfaceVulkanLaunchCallOp.getOperand(2)});
435
436 // Create call to `runOnVulkan` runtime function.
437 builder.create<LLVM::CallOp>(loc, TypeRange(), kRunOnVulkan,
438 ValueRange{vulkanRuntime});
439
440 // Create call to 'deinitVulkan' runtime function.
441 builder.create<LLVM::CallOp>(loc, TypeRange(), kDeinitVulkan,
442 ValueRange{vulkanRuntime});
443
444 // Declare runtime functions.
445 declareVulkanFunctions(loc);
446
447 cInterfaceVulkanLaunchCallOp.erase();
448}
449

source code of mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp