1 | //===- mlir-spirv-cpu-runner.cpp - MLIR SPIR-V Execution on CPU -----------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Main entry point to a command line utility that executes an MLIR file on the |
10 | // CPU by translating MLIR GPU module and host part to LLVM IR before |
11 | // JIT-compiling and executing. |
12 | // |
13 | //===----------------------------------------------------------------------===// |
14 | |
15 | #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" |
16 | #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h" |
17 | #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h" |
18 | #include "mlir/Dialect/Arith/IR/Arith.h" |
19 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
20 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
21 | #include "mlir/Dialect/GPU/Transforms/Passes.h" |
22 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
23 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
24 | #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" |
25 | #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
26 | #include "mlir/Dialect/SPIRV/Transforms/Passes.h" |
27 | #include "mlir/ExecutionEngine/JitRunner.h" |
28 | #include "mlir/ExecutionEngine/OptUtils.h" |
29 | #include "mlir/Pass/Pass.h" |
30 | #include "mlir/Pass/PassManager.h" |
31 | #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" |
32 | #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" |
33 | #include "mlir/Target/LLVMIR/Export.h" |
34 | |
35 | #include "llvm/IR/LLVMContext.h" |
36 | #include "llvm/IR/Module.h" |
37 | #include "llvm/Linker/Linker.h" |
38 | #include "llvm/Support/InitLLVM.h" |
39 | #include "llvm/Support/TargetSelect.h" |
40 | |
41 | using namespace mlir; |
42 | |
43 | /// A utility function that builds llvm::Module from two nested MLIR modules. |
44 | /// |
45 | /// module @main { |
46 | /// module @kernel { |
47 | /// // Some ops |
48 | /// } |
49 | /// // Some other ops |
50 | /// } |
51 | /// |
52 | /// Each of these two modules is translated to LLVM IR module, then they are |
53 | /// linked together and returned. |
54 | static std::unique_ptr<llvm::Module> |
55 | convertMLIRModule(Operation *op, llvm::LLVMContext &context) { |
56 | auto module = dyn_cast<ModuleOp>(op); |
57 | if (!module) |
58 | return op->emitError(message: "op must be a 'builtin.module" ), nullptr; |
59 | // Verify that there is only one nested module. |
60 | auto modules = module.getOps<ModuleOp>(); |
61 | if (!llvm::hasSingleElement(modules)) { |
62 | module.emitError("The module must contain exactly one nested module" ); |
63 | return nullptr; |
64 | } |
65 | |
66 | // Translate nested module and erase it. |
67 | ModuleOp nested = *modules.begin(); |
68 | std::unique_ptr<llvm::Module> kernelModule = |
69 | translateModuleToLLVMIR(nested, context); |
70 | nested.erase(); |
71 | |
72 | std::unique_ptr<llvm::Module> mainModule = |
73 | translateModuleToLLVMIR(module, context); |
74 | llvm::Linker::linkModules(Dest&: *mainModule, Src: std::move(kernelModule)); |
75 | return mainModule; |
76 | } |
77 | |
78 | static LogicalResult runMLIRPasses(Operation *module, |
79 | JitRunnerOptions &options) { |
80 | PassManager passManager(module->getContext(), |
81 | module->getName().getStringRef()); |
82 | if (failed(result: applyPassManagerCLOptions(pm&: passManager))) |
83 | return failure(); |
84 | passManager.addPass(createGpuKernelOutliningPass()); |
85 | passManager.addPass(createConvertGPUToSPIRVPass(/*mapMemorySpace=*/true)); |
86 | |
87 | OpPassManager &nestedPM = passManager.nest<spirv::ModuleOp>(); |
88 | nestedPM.addPass(spirv::createSPIRVLowerABIAttributesPass()); |
89 | nestedPM.addPass(spirv::createSPIRVUpdateVCEPass()); |
90 | passManager.addPass(pass: createLowerHostCodeToLLVMPass()); |
91 | passManager.addPass(pass: createConvertSPIRVToLLVMPass()); |
92 | return passManager.run(op: module); |
93 | } |
94 | |
95 | int main(int argc, char **argv) { |
96 | llvm::InitLLVM y(argc, argv); |
97 | |
98 | llvm::InitializeNativeTarget(); |
99 | llvm::InitializeNativeTargetAsmPrinter(); |
100 | |
101 | mlir::JitRunnerConfig jitRunnerConfig; |
102 | jitRunnerConfig.mlirTransformer = runMLIRPasses; |
103 | jitRunnerConfig.llvmModuleBuilder = convertMLIRModule; |
104 | |
105 | mlir::DialectRegistry registry; |
106 | registry.insert<mlir::arith::ArithDialect, mlir::LLVM::LLVMDialect, |
107 | mlir::gpu::GPUDialect, mlir::spirv::SPIRVDialect, |
108 | mlir::func::FuncDialect, mlir::memref::MemRefDialect>(); |
109 | mlir::registerPassManagerCLOptions(); |
110 | mlir::registerBuiltinDialectTranslation(registry); |
111 | mlir::registerLLVMDialectTranslation(registry); |
112 | |
113 | return mlir::JitRunnerMain(argc, argv, registry, config: jitRunnerConfig); |
114 | } |
115 | |