1 | //===- ModuleToBinary.cpp - Transforms GPU modules to GPU binaries ----------=// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements the `GpuModuleToBinaryPass` pass, transforming GPU |
10 | // modules into GPU binaries. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Dialect/GPU/Transforms/Passes.h" |
15 | |
16 | #include "mlir/Config/mlir-config.h" |
17 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
18 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
19 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
20 | #include "mlir/Dialect/LLVMIR/NVVMDialect.h" |
21 | #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" |
22 | #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" |
23 | #include "mlir/IR/BuiltinOps.h" |
24 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
25 | |
26 | #include "llvm/ADT/STLExtras.h" |
27 | #include "llvm/ADT/StringSwitch.h" |
28 | |
29 | using namespace mlir; |
30 | using namespace mlir::gpu; |
31 | |
32 | namespace mlir { |
33 | #define GEN_PASS_DEF_GPUMODULETOBINARYPASS |
34 | #include "mlir/Dialect/GPU/Transforms/Passes.h.inc" |
35 | } // namespace mlir |
36 | |
37 | namespace { |
38 | class GpuModuleToBinaryPass |
39 | : public impl::GpuModuleToBinaryPassBase<GpuModuleToBinaryPass> { |
40 | public: |
41 | using Base::Base; |
42 | void getDependentDialects(DialectRegistry ®istry) const override; |
43 | void runOnOperation() final; |
44 | }; |
45 | } // namespace |
46 | |
47 | void GpuModuleToBinaryPass::getDependentDialects( |
48 | DialectRegistry ®istry) const { |
49 | // Register all GPU related translations. |
50 | registry.insert<gpu::GPUDialect>(); |
51 | registry.insert<LLVM::LLVMDialect>(); |
52 | #if MLIR_ENABLE_CUDA_CONVERSIONS |
53 | registry.insert<NVVM::NVVMDialect>(); |
54 | #endif |
55 | #if MLIR_ENABLE_ROCM_CONVERSIONS |
56 | registry.insert<ROCDL::ROCDLDialect>(); |
57 | #endif |
58 | registry.insert<spirv::SPIRVDialect>(); |
59 | } |
60 | |
61 | void GpuModuleToBinaryPass::runOnOperation() { |
62 | RewritePatternSet patterns(&getContext()); |
63 | auto targetFormat = |
64 | llvm::StringSwitch<std::optional<CompilationTarget>>(compilationTarget) |
65 | .Cases("offloading" , "llvm" , CompilationTarget::Offload) |
66 | .Cases("assembly" , "isa" , CompilationTarget::Assembly) |
67 | .Cases("binary" , "bin" , CompilationTarget::Binary) |
68 | .Cases("fatbinary" , "fatbin" , CompilationTarget::Fatbin) |
69 | .Default(std::nullopt); |
70 | if (!targetFormat) |
71 | getOperation()->emitError() << "Invalid format specified." ; |
72 | |
73 | // Lazy symbol table builder callback. |
74 | std::optional<SymbolTable> parentTable; |
75 | auto lazyTableBuilder = [&]() -> SymbolTable * { |
76 | // Build the table if it has not been built. |
77 | if (!parentTable) { |
78 | Operation *table = SymbolTable::getNearestSymbolTable(from: getOperation()); |
79 | // It's up to the target attribute to determine if failing to find a |
80 | // symbol table is an error. |
81 | if (!table) |
82 | return nullptr; |
83 | parentTable = SymbolTable(table); |
84 | } |
85 | return &parentTable.value(); |
86 | }; |
87 | |
88 | TargetOptions targetOptions(toolkitPath, linkFiles, cmdOptions, *targetFormat, |
89 | lazyTableBuilder); |
90 | if (failed(transformGpuModulesToBinaries( |
91 | getOperation(), |
92 | offloadingHandler ? dyn_cast<OffloadingLLVMTranslationAttrInterface>( |
93 | offloadingHandler.getValue()) |
94 | : OffloadingLLVMTranslationAttrInterface(nullptr), |
95 | targetOptions))) |
96 | return signalPassFailure(); |
97 | } |
98 | |
99 | namespace { |
100 | LogicalResult moduleSerializer(GPUModuleOp op, |
101 | OffloadingLLVMTranslationAttrInterface handler, |
102 | const TargetOptions &targetOptions) { |
103 | OpBuilder builder(op->getContext()); |
104 | SmallVector<Attribute> objects; |
105 | // Fail if there are no target attributes |
106 | if (!op.getTargetsAttr()) |
107 | return op.emitError("the module has no target attributes" ); |
108 | // Serialize all targets. |
109 | for (auto targetAttr : op.getTargetsAttr()) { |
110 | assert(targetAttr && "Target attribute cannot be null." ); |
111 | auto target = dyn_cast<gpu::TargetAttrInterface>(targetAttr); |
112 | assert(target && |
113 | "Target attribute doesn't implements `TargetAttrInterface`." ); |
114 | std::optional<SmallVector<char, 0>> serializedModule = |
115 | target.serializeToObject(op, targetOptions); |
116 | if (!serializedModule) { |
117 | op.emitError("An error happened while serializing the module." ); |
118 | return failure(); |
119 | } |
120 | |
121 | Attribute object = target.createObject(*serializedModule, targetOptions); |
122 | if (!object) { |
123 | op.emitError("An error happened while creating the object." ); |
124 | return failure(); |
125 | } |
126 | objects.push_back(object); |
127 | } |
128 | if (auto moduleHandler = |
129 | dyn_cast_or_null<OffloadingLLVMTranslationAttrInterface>( |
130 | op.getOffloadingHandlerAttr()); |
131 | !handler && moduleHandler) |
132 | handler = moduleHandler; |
133 | builder.setInsertionPointAfter(op); |
134 | builder.create<gpu::BinaryOp>(op.getLoc(), op.getName(), handler, |
135 | builder.getArrayAttr(objects)); |
136 | op->erase(); |
137 | return success(); |
138 | } |
139 | } // namespace |
140 | |
141 | LogicalResult mlir::gpu::transformGpuModulesToBinaries( |
142 | Operation *op, OffloadingLLVMTranslationAttrInterface handler, |
143 | const gpu::TargetOptions &targetOptions) { |
144 | for (Region ®ion : op->getRegions()) |
145 | for (Block &block : region.getBlocks()) |
146 | for (auto module : |
147 | llvm::make_early_inc_range(block.getOps<GPUModuleOp>())) |
148 | if (failed(moduleSerializer(module, handler, targetOptions))) |
149 | return failure(); |
150 | return success(); |
151 | } |
152 | |