1 | //===- mlir-vulkan-runner.cpp - MLIR Vulkan Execution Driver --------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This is a command line utility that executes an MLIR file on Vulkan by
// translating the MLIR GPU module to SPIR-V and the host part to LLVM IR
// before JIT-compiling and executing the latter.
12 | // |
13 | //===----------------------------------------------------------------------===// |
14 | |
15 | #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" |
16 | #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h" |
17 | #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h" |
18 | #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" |
19 | #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" |
20 | #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" |
21 | #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h" |
22 | #include "mlir/Dialect/Arith/IR/Arith.h" |
23 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
24 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
25 | #include "mlir/Dialect/GPU/Transforms/Passes.h" |
26 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
27 | #include "mlir/Dialect/LLVMIR/Transforms/RequestCWrappers.h" |
28 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
29 | #include "mlir/Dialect/MemRef/Transforms/Passes.h" |
30 | #include "mlir/Dialect/SCF/IR/SCF.h" |
31 | #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" |
32 | #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
33 | #include "mlir/Dialect/SPIRV/Transforms/Passes.h" |
34 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
35 | #include "mlir/ExecutionEngine/JitRunner.h" |
36 | #include "mlir/Pass/Pass.h" |
37 | #include "mlir/Pass/PassManager.h" |
38 | #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" |
39 | #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" |
40 | #include "llvm/Support/InitLLVM.h" |
41 | #include "llvm/Support/TargetSelect.h" |
42 | |
43 | using namespace mlir; |
44 | |
45 | namespace { |
46 | struct VulkanRunnerOptions { |
47 | llvm::cl::OptionCategory category{"mlir-vulkan-runner options" }; |
48 | llvm::cl::opt<bool> spirvWebGPUPrepare{ |
49 | "vulkan-runner-spirv-webgpu-prepare" , |
50 | llvm::cl::desc("Run MLIR transforms used when targetting WebGPU" ), |
51 | llvm::cl::cat(category)}; |
52 | }; |
53 | } // namespace |
54 | |
55 | static LogicalResult runMLIRPasses(Operation *op, |
56 | VulkanRunnerOptions &options) { |
57 | auto module = dyn_cast<ModuleOp>(op); |
58 | if (!module) |
59 | return op->emitOpError(message: "expected a 'builtin.module' op" ); |
60 | PassManager passManager(module.getContext()); |
61 | if (failed(result: applyPassManagerCLOptions(pm&: passManager))) |
62 | return failure(); |
63 | |
64 | passManager.addPass(createGpuKernelOutliningPass()); |
65 | passManager.addPass(pass: memref::createFoldMemRefAliasOpsPass()); |
66 | |
67 | passManager.addPass(createConvertGPUToSPIRVPass(/*mapMemorySpace=*/true)); |
68 | OpPassManager &modulePM = passManager.nest<spirv::ModuleOp>(); |
69 | modulePM.addPass(spirv::createSPIRVLowerABIAttributesPass()); |
70 | modulePM.addPass(spirv::createSPIRVUpdateVCEPass()); |
71 | if (options.spirvWebGPUPrepare) |
72 | modulePM.addPass(spirv::createSPIRVWebGPUPreparePass()); |
73 | |
74 | passManager.addPass(createConvertGpuLaunchFuncToVulkanLaunchFuncPass()); |
75 | passManager.addPass(pass: createFinalizeMemRefToLLVMConversionPass()); |
76 | passManager.addPass(pass: createConvertVectorToLLVMPass()); |
77 | passManager.nest<func::FuncOp>().addPass(pass: LLVM::createRequestCWrappersPass()); |
78 | ConvertFuncToLLVMPassOptions funcToLLVMOptions{}; |
79 | funcToLLVMOptions.indexBitwidth = |
80 | DataLayout(module).getTypeSizeInBits(IndexType::get(module.getContext())); |
81 | passManager.addPass(pass: createConvertFuncToLLVMPass(funcToLLVMOptions)); |
82 | passManager.addPass(pass: createReconcileUnrealizedCastsPass()); |
83 | passManager.addPass(pass: createConvertVulkanLaunchFuncToVulkanCallsPass()); |
84 | |
85 | return passManager.run(op: module); |
86 | } |
87 | |
88 | int main(int argc, char **argv) { |
89 | llvm::llvm_shutdown_obj x; |
90 | registerPassManagerCLOptions(); |
91 | |
92 | llvm::InitLLVM y(argc, argv); |
93 | llvm::InitializeNativeTarget(); |
94 | llvm::InitializeNativeTargetAsmPrinter(); |
95 | |
96 | // Initialize runner-specific CLI options. These will be parsed and |
97 | // initialzied in `JitRunnerMain`. |
98 | VulkanRunnerOptions options; |
99 | auto runPassesWithOptions = [&options](Operation *op, JitRunnerOptions &) { |
100 | return runMLIRPasses(op, options); |
101 | }; |
102 | |
103 | mlir::JitRunnerConfig jitRunnerConfig; |
104 | jitRunnerConfig.mlirTransformer = runPassesWithOptions; |
105 | |
106 | mlir::DialectRegistry registry; |
107 | registry.insert<mlir::arith::ArithDialect, mlir::LLVM::LLVMDialect, |
108 | mlir::gpu::GPUDialect, mlir::spirv::SPIRVDialect, |
109 | mlir::scf::SCFDialect, mlir::func::FuncDialect, |
110 | mlir::memref::MemRefDialect, mlir::vector::VectorDialect>(); |
111 | mlir::registerBuiltinDialectTranslation(registry); |
112 | mlir::registerLLVMDialectTranslation(registry); |
113 | |
114 | return mlir::JitRunnerMain(argc, argv, registry, config: jitRunnerConfig); |
115 | } |
116 | |