1 | //===- ConvertLaunchFuncToVulkanCalls.cpp - MLIR Vulkan conversion passes -===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This file implements a pass to convert a Vulkan launch call into a sequence
// of Vulkan runtime calls. The Vulkan runtime API surface is huge, so instead
// of exposing a separate external function in IR for each runtime entry point,
// we expose a few external functions provided by wrapper libraries that manage
// the Vulkan runtime.
14 | // |
15 | //===----------------------------------------------------------------------===// |
16 | |
17 | #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h" |
18 | |
19 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
20 | #include "mlir/IR/Attributes.h" |
21 | #include "mlir/IR/Builders.h" |
22 | #include "mlir/IR/BuiltinOps.h" |
23 | #include "mlir/Pass/Pass.h" |
24 | #include "llvm/ADT/SmallString.h" |
25 | #include "llvm/Support/FormatVariadic.h" |
26 | |
27 | namespace mlir { |
28 | #define GEN_PASS_DEF_CONVERTVULKANLAUNCHFUNCTOVULKANCALLSPASS |
29 | #include "mlir/Conversion/Passes.h.inc" |
30 | } // namespace mlir |
31 | |
32 | using namespace mlir; |
33 | |
// Vulkan runtime wrapper functions that the lowered code calls into.
static constexpr const char *kInitVulkan{"initVulkan"};
static constexpr const char *kDeinitVulkan{"deinitVulkan"};
static constexpr const char *kRunOnVulkan{"runOnVulkan"};
static constexpr const char *kSetBinaryShader{"setBinaryShader"};
static constexpr const char *kSetEntryPoint{"setEntryPoint"};
static constexpr const char *kSetNumWorkGroups{"setNumWorkGroups"};

// Callee names identifying the launch calls handled by this pass.
static constexpr const char *kVulkanLaunch{"vulkanLaunch"};
static constexpr const char *kCInterfaceVulkanLaunch{
    "_mlir_ciface_vulkanLaunch"};

// Attribute names read from the `vulkanLaunch` call.
static constexpr const char *kSPIRVBlobAttrName{"spirv_blob"};
static constexpr const char *kSPIRVEntryPointAttrName{"spirv_entry_point"};
static constexpr const char *kSPIRVElementTypesAttrName{"spirv_element_types"};

// Symbol name of the LLVM global that holds the SPIR-V binary.
static constexpr const char *kSPIRVBinary{"SPIRV_BIN"};
47 | |
48 | namespace { |
49 | |
50 | /// A pass to convert vulkan launch call op into a sequence of Vulkan |
51 | /// runtime calls in the following order: |
52 | /// |
53 | /// * initVulkan -- initializes vulkan runtime |
54 | /// * bindMemRef -- binds memref |
55 | /// * setBinaryShader -- sets the binary shader data |
56 | /// * setEntryPoint -- sets the entry point name |
57 | /// * setNumWorkGroups -- sets the number of a local workgroups |
58 | /// * runOnVulkan -- runs vulkan runtime |
59 | /// * deinitVulkan -- deinitializes vulkan runtime |
60 | /// |
61 | class VulkanLaunchFuncToVulkanCallsPass |
62 | : public impl::ConvertVulkanLaunchFuncToVulkanCallsPassBase< |
63 | VulkanLaunchFuncToVulkanCallsPass> { |
64 | private: |
65 | void initializeCachedTypes() { |
66 | llvmFloatType = Float32Type::get(&getContext()); |
67 | llvmVoidType = LLVM::LLVMVoidType::get(&getContext()); |
68 | llvmPointerType = LLVM::LLVMPointerType::get(&getContext()); |
69 | llvmInt32Type = IntegerType::get(&getContext(), 32); |
70 | llvmInt64Type = IntegerType::get(&getContext(), 64); |
71 | } |
72 | |
73 | Type getMemRefType(uint32_t rank, Type elemenType) { |
74 | // According to the MLIR doc memref argument is converted into a |
75 | // pointer-to-struct argument of type: |
76 | // template <typename Elem, size_t Rank> |
77 | // struct { |
78 | // Elem *allocated; |
79 | // Elem *aligned; |
80 | // int64_t offset; |
81 | // int64_t sizes[Rank]; // omitted when rank == 0 |
82 | // int64_t strides[Rank]; // omitted when rank == 0 |
83 | // }; |
84 | auto llvmArrayRankElementSizeType = |
85 | LLVM::LLVMArrayType::get(getInt64Type(), rank); |
86 | |
87 | // Create a type |
88 | // `!llvm<"{ `element-type`*, `element-type`*, i64, |
89 | // [`rank` x i64], [`rank` x i64]}">`. |
90 | return LLVM::LLVMStructType::getLiteral( |
91 | context: &getContext(), |
92 | types: {llvmPointerType, llvmPointerType, getInt64Type(), |
93 | llvmArrayRankElementSizeType, llvmArrayRankElementSizeType}); |
94 | } |
95 | |
96 | Type getVoidType() { return llvmVoidType; } |
97 | Type getPointerType() { return llvmPointerType; } |
98 | Type getInt32Type() { return llvmInt32Type; } |
99 | Type getInt64Type() { return llvmInt64Type; } |
100 | |
101 | /// Creates an LLVM global for the given `name`. |
102 | Value createEntryPointNameConstant(StringRef name, Location loc, |
103 | OpBuilder &builder); |
104 | |
105 | /// Declares all needed runtime functions. |
106 | void declareVulkanFunctions(Location loc); |
107 | |
108 | /// Checks whether the given LLVM::CallOp is a vulkan launch call op. |
109 | bool isVulkanLaunchCallOp(LLVM::CallOp callOp) { |
110 | return (callOp.getCallee() && *callOp.getCallee() == kVulkanLaunch && |
111 | callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands); |
112 | } |
113 | |
114 | /// Checks whether the given LLVM::CallOp is a "ci_face" vulkan launch call |
115 | /// op. |
116 | bool isCInterfaceVulkanLaunchCallOp(LLVM::CallOp callOp) { |
117 | return (callOp.getCallee() && |
118 | *callOp.getCallee() == kCInterfaceVulkanLaunch && |
119 | callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands); |
120 | } |
121 | |
122 | /// Translates the given `vulkanLaunchCallOp` to the sequence of Vulkan |
123 | /// runtime calls. |
124 | void translateVulkanLaunchCall(LLVM::CallOp vulkanLaunchCallOp); |
125 | |
126 | /// Creates call to `bindMemRef` for each memref operand. |
127 | void createBindMemRefCalls(LLVM::CallOp vulkanLaunchCallOp, |
128 | Value vulkanRuntime); |
129 | |
130 | /// Collects SPIRV attributes from the given `vulkanLaunchCallOp`. |
131 | void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp); |
132 | |
133 | /// Deduces a rank from the given 'launchCallArg`. |
134 | LogicalResult deduceMemRefRank(Value launchCallArg, uint32_t &rank); |
135 | |
136 | /// Returns a string representation from the given `type`. |
137 | StringRef stringifyType(Type type) { |
138 | if (isa<Float32Type>(type)) |
139 | return "Float" ; |
140 | if (isa<Float16Type>(type)) |
141 | return "Half" ; |
142 | if (auto intType = dyn_cast<IntegerType>(type)) { |
143 | if (intType.getWidth() == 32) |
144 | return "Int32" ; |
145 | if (intType.getWidth() == 16) |
146 | return "Int16" ; |
147 | if (intType.getWidth() == 8) |
148 | return "Int8" ; |
149 | } |
150 | |
151 | llvm_unreachable("unsupported type" ); |
152 | } |
153 | |
154 | public: |
155 | using Base::Base; |
156 | |
157 | void runOnOperation() override; |
158 | |
159 | private: |
160 | Type llvmFloatType; |
161 | Type llvmVoidType; |
162 | Type llvmPointerType; |
163 | Type llvmInt32Type; |
164 | Type llvmInt64Type; |
165 | |
166 | struct SPIRVAttributes { |
167 | StringAttr blob; |
168 | StringAttr entryPoint; |
169 | SmallVector<Type> elementTypes; |
170 | }; |
171 | |
172 | // TODO: Use an associative array to support multiple vulkan launch calls. |
173 | SPIRVAttributes spirvAttributes; |
174 | /// The number of vulkan launch configuration operands, placed at the leading |
175 | /// positions of the operand list. |
176 | static constexpr unsigned kVulkanLaunchNumConfigOperands = 3; |
177 | }; |
178 | |
179 | } // namespace |
180 | |
181 | void VulkanLaunchFuncToVulkanCallsPass::runOnOperation() { |
182 | initializeCachedTypes(); |
183 | |
184 | // Collect SPIR-V attributes such as `spirv_blob` and |
185 | // `spirv_entry_point_name`. |
186 | getOperation().walk([this](LLVM::CallOp op) { |
187 | if (isVulkanLaunchCallOp(op)) |
188 | collectSPIRVAttributes(op); |
189 | }); |
190 | |
191 | // Convert vulkan launch call op into a sequence of Vulkan runtime calls. |
192 | getOperation().walk([this](LLVM::CallOp op) { |
193 | if (isCInterfaceVulkanLaunchCallOp(op)) |
194 | translateVulkanLaunchCall(op); |
195 | }); |
196 | } |
197 | |
198 | void VulkanLaunchFuncToVulkanCallsPass::collectSPIRVAttributes( |
199 | LLVM::CallOp vulkanLaunchCallOp) { |
200 | // Check that `kSPIRVBinary` and `kSPIRVEntryPoint` are present in attributes |
201 | // for the given vulkan launch call. |
202 | auto spirvBlobAttr = |
203 | vulkanLaunchCallOp->getAttrOfType<StringAttr>(kSPIRVBlobAttrName); |
204 | if (!spirvBlobAttr) { |
205 | vulkanLaunchCallOp.emitError() |
206 | << "missing " << kSPIRVBlobAttrName << " attribute" ; |
207 | return signalPassFailure(); |
208 | } |
209 | |
210 | auto spirvEntryPointNameAttr = |
211 | vulkanLaunchCallOp->getAttrOfType<StringAttr>(kSPIRVEntryPointAttrName); |
212 | if (!spirvEntryPointNameAttr) { |
213 | vulkanLaunchCallOp.emitError() |
214 | << "missing " << kSPIRVEntryPointAttrName << " attribute" ; |
215 | return signalPassFailure(); |
216 | } |
217 | |
218 | auto spirvElementTypesAttr = |
219 | vulkanLaunchCallOp->getAttrOfType<ArrayAttr>(kSPIRVElementTypesAttrName); |
220 | if (!spirvElementTypesAttr) { |
221 | vulkanLaunchCallOp.emitError() |
222 | << "missing " << kSPIRVElementTypesAttrName << " attribute" ; |
223 | return signalPassFailure(); |
224 | } |
225 | if (llvm::any_of(spirvElementTypesAttr, |
226 | [](Attribute attr) { return !isa<TypeAttr>(Val: attr); })) { |
227 | vulkanLaunchCallOp.emitError() |
228 | << "expected " << spirvElementTypesAttr << " to be an array of types" ; |
229 | return signalPassFailure(); |
230 | } |
231 | |
232 | spirvAttributes.blob = spirvBlobAttr; |
233 | spirvAttributes.entryPoint = spirvEntryPointNameAttr; |
234 | spirvAttributes.elementTypes = |
235 | llvm::to_vector(spirvElementTypesAttr.getAsValueRange<mlir::TypeAttr>()); |
236 | } |
237 | |
238 | void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls( |
239 | LLVM::CallOp cInterfaceVulkanLaunchCallOp, Value vulkanRuntime) { |
240 | if (cInterfaceVulkanLaunchCallOp.getNumOperands() == |
241 | kVulkanLaunchNumConfigOperands) |
242 | return; |
243 | OpBuilder builder(cInterfaceVulkanLaunchCallOp); |
244 | Location loc = cInterfaceVulkanLaunchCallOp.getLoc(); |
245 | |
246 | // Create LLVM constant for the descriptor set index. |
247 | // Bind all memrefs to the `0` descriptor set, the same way as `GPUToSPIRV` |
248 | // pass does. |
249 | Value descriptorSet = |
250 | builder.create<LLVM::ConstantOp>(loc, getInt32Type(), 0); |
251 | |
252 | for (auto [index, ptrToMemRefDescriptor] : |
253 | llvm::enumerate(cInterfaceVulkanLaunchCallOp.getOperands().drop_front( |
254 | kVulkanLaunchNumConfigOperands))) { |
255 | // Create LLVM constant for the descriptor binding index. |
256 | Value descriptorBinding = |
257 | builder.create<LLVM::ConstantOp>(loc, getInt32Type(), index); |
258 | |
259 | if (index >= spirvAttributes.elementTypes.size()) { |
260 | cInterfaceVulkanLaunchCallOp.emitError() |
261 | << kSPIRVElementTypesAttrName << " missing element type for " |
262 | << ptrToMemRefDescriptor; |
263 | return signalPassFailure(); |
264 | } |
265 | |
266 | uint32_t rank = 0; |
267 | Type type = spirvAttributes.elementTypes[index]; |
268 | if (failed(deduceMemRefRank(ptrToMemRefDescriptor, rank))) { |
269 | cInterfaceVulkanLaunchCallOp.emitError() |
270 | << "invalid memref descriptor " << ptrToMemRefDescriptor.getType(); |
271 | return signalPassFailure(); |
272 | } |
273 | |
274 | auto symbolName = |
275 | llvm::formatv("bindMemRef{0}D{1}" , rank, stringifyType(type)).str(); |
276 | // Create call to `bindMemRef`. |
277 | builder.create<LLVM::CallOp>( |
278 | loc, TypeRange(), StringRef(symbolName.data(), symbolName.size()), |
279 | ValueRange{vulkanRuntime, descriptorSet, descriptorBinding, |
280 | ptrToMemRefDescriptor}); |
281 | } |
282 | } |
283 | |
284 | LogicalResult |
285 | VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRank(Value launchCallArg, |
286 | uint32_t &rank) { |
287 | // Deduce the rank from the type used to allocate the lowered MemRef. |
288 | auto alloca = launchCallArg.getDefiningOp<LLVM::AllocaOp>(); |
289 | if (!alloca) |
290 | return failure(); |
291 | |
292 | std::optional<Type> elementType = alloca.getElemType(); |
293 | assert(elementType && "expected to work with opaque pointers" ); |
294 | auto llvmDescriptorTy = dyn_cast<LLVM::LLVMStructType>(Val&: *elementType); |
295 | // template <typename Elem, size_t Rank> |
296 | // struct { |
297 | // Elem *allocated; |
298 | // Elem *aligned; |
299 | // int64_t offset; |
300 | // int64_t sizes[Rank]; // omitted when rank == 0 |
301 | // int64_t strides[Rank]; // omitted when rank == 0 |
302 | // }; |
303 | if (!llvmDescriptorTy) |
304 | return failure(); |
305 | |
306 | if (llvmDescriptorTy.getBody().size() == 3) { |
307 | rank = 0; |
308 | return success(); |
309 | } |
310 | rank = |
311 | cast<LLVM::LLVMArrayType>(llvmDescriptorTy.getBody()[3]).getNumElements(); |
312 | return success(); |
313 | } |
314 | |
315 | void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) { |
316 | ModuleOp module = getOperation(); |
317 | auto builder = OpBuilder::atBlockEnd(block: module.getBody()); |
318 | |
319 | if (!module.lookupSymbol(kSetEntryPoint)) { |
320 | builder.create<LLVM::LLVMFuncOp>( |
321 | loc, kSetEntryPoint, |
322 | LLVM::LLVMFunctionType::get(getVoidType(), |
323 | {getPointerType(), getPointerType()})); |
324 | } |
325 | |
326 | if (!module.lookupSymbol(kSetNumWorkGroups)) { |
327 | builder.create<LLVM::LLVMFuncOp>( |
328 | loc, kSetNumWorkGroups, |
329 | LLVM::LLVMFunctionType::get(getVoidType(), |
330 | {getPointerType(), getInt64Type(), |
331 | getInt64Type(), getInt64Type()})); |
332 | } |
333 | |
334 | if (!module.lookupSymbol(kSetBinaryShader)) { |
335 | builder.create<LLVM::LLVMFuncOp>( |
336 | loc, kSetBinaryShader, |
337 | LLVM::LLVMFunctionType::get( |
338 | getVoidType(), |
339 | {getPointerType(), getPointerType(), getInt32Type()})); |
340 | } |
341 | |
342 | if (!module.lookupSymbol(kRunOnVulkan)) { |
343 | builder.create<LLVM::LLVMFuncOp>( |
344 | loc, kRunOnVulkan, |
345 | LLVM::LLVMFunctionType::get(getVoidType(), {getPointerType()})); |
346 | } |
347 | |
348 | for (unsigned i = 1; i <= 3; i++) { |
349 | SmallVector<Type, 5> types{ |
350 | Float32Type::get(&getContext()), IntegerType::get(&getContext(), 32), |
351 | IntegerType::get(&getContext(), 16), IntegerType::get(&getContext(), 8), |
352 | Float16Type::get(&getContext())}; |
353 | for (auto type : types) { |
354 | std::string fnName = "bindMemRef" + std::to_string(val: i) + "D" + |
355 | std::string(stringifyType(type)); |
356 | if (isa<Float16Type>(type)) |
357 | type = IntegerType::get(&getContext(), 16); |
358 | if (!module.lookupSymbol(fnName)) { |
359 | auto fnType = LLVM::LLVMFunctionType::get( |
360 | getVoidType(), |
361 | {llvmPointerType, getInt32Type(), getInt32Type(), llvmPointerType}, |
362 | /*isVarArg=*/false); |
363 | builder.create<LLVM::LLVMFuncOp>(loc, fnName, fnType); |
364 | } |
365 | } |
366 | } |
367 | |
368 | if (!module.lookupSymbol(kInitVulkan)) { |
369 | builder.create<LLVM::LLVMFuncOp>( |
370 | loc, kInitVulkan, LLVM::LLVMFunctionType::get(getPointerType(), {})); |
371 | } |
372 | |
373 | if (!module.lookupSymbol(kDeinitVulkan)) { |
374 | builder.create<LLVM::LLVMFuncOp>( |
375 | loc, kDeinitVulkan, |
376 | LLVM::LLVMFunctionType::get(getVoidType(), {getPointerType()})); |
377 | } |
378 | } |
379 | |
380 | Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant( |
381 | StringRef name, Location loc, OpBuilder &builder) { |
382 | SmallString<16> shaderName(name.begin(), name.end()); |
383 | // Append `\0` to follow C style string given that LLVM::createGlobalString() |
384 | // won't handle this directly for us. |
385 | shaderName.push_back(Elt: '\0'); |
386 | |
387 | std::string entryPointGlobalName = (name + "_spv_entry_point_name" ).str(); |
388 | return LLVM::createGlobalString(loc, builder, entryPointGlobalName, |
389 | shaderName, LLVM::Linkage::Internal); |
390 | } |
391 | |
392 | void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall( |
393 | LLVM::CallOp cInterfaceVulkanLaunchCallOp) { |
394 | OpBuilder builder(cInterfaceVulkanLaunchCallOp); |
395 | Location loc = cInterfaceVulkanLaunchCallOp.getLoc(); |
396 | // Create call to `initVulkan`. |
397 | auto initVulkanCall = builder.create<LLVM::CallOp>( |
398 | loc, TypeRange{getPointerType()}, kInitVulkan); |
399 | // The result of `initVulkan` function is a pointer to Vulkan runtime, we |
400 | // need to pass that pointer to each Vulkan runtime call. |
401 | auto vulkanRuntime = initVulkanCall.getResult(); |
402 | |
403 | // Create LLVM global with SPIR-V binary data, so we can pass a pointer with |
404 | // that data to runtime call. |
405 | Value ptrToSPIRVBinary = LLVM::createGlobalString( |
406 | loc, builder, kSPIRVBinary, spirvAttributes.blob.getValue(), |
407 | LLVM::Linkage::Internal); |
408 | |
409 | // Create LLVM constant for the size of SPIR-V binary shader. |
410 | Value binarySize = builder.create<LLVM::ConstantOp>( |
411 | loc, getInt32Type(), spirvAttributes.blob.getValue().size()); |
412 | |
413 | // Create call to `bindMemRef` for each memref operand. |
414 | createBindMemRefCalls(cInterfaceVulkanLaunchCallOp, vulkanRuntime); |
415 | |
416 | // Create call to `setBinaryShader` runtime function with the given pointer to |
417 | // SPIR-V binary and binary size. |
418 | builder.create<LLVM::CallOp>( |
419 | loc, TypeRange(), kSetBinaryShader, |
420 | ValueRange{vulkanRuntime, ptrToSPIRVBinary, binarySize}); |
421 | // Create LLVM global with entry point name. |
422 | Value entryPointName = createEntryPointNameConstant( |
423 | spirvAttributes.entryPoint.getValue(), loc, builder); |
424 | // Create call to `setEntryPoint` runtime function with the given pointer to |
425 | // entry point name. |
426 | builder.create<LLVM::CallOp>(loc, TypeRange(), kSetEntryPoint, |
427 | ValueRange{vulkanRuntime, entryPointName}); |
428 | |
429 | // Create number of local workgroup for each dimension. |
430 | builder.create<LLVM::CallOp>( |
431 | loc, TypeRange(), kSetNumWorkGroups, |
432 | ValueRange{vulkanRuntime, cInterfaceVulkanLaunchCallOp.getOperand(0), |
433 | cInterfaceVulkanLaunchCallOp.getOperand(1), |
434 | cInterfaceVulkanLaunchCallOp.getOperand(2)}); |
435 | |
436 | // Create call to `runOnVulkan` runtime function. |
437 | builder.create<LLVM::CallOp>(loc, TypeRange(), kRunOnVulkan, |
438 | ValueRange{vulkanRuntime}); |
439 | |
440 | // Create call to 'deinitVulkan' runtime function. |
441 | builder.create<LLVM::CallOp>(loc, TypeRange(), kDeinitVulkan, |
442 | ValueRange{vulkanRuntime}); |
443 | |
444 | // Declare runtime functions. |
445 | declareVulkanFunctions(loc); |
446 | |
447 | cInterfaceVulkanLaunchCallOp.erase(); |
448 | } |
449 | |