1 | //===- ModuleToBinary.cpp - Transforms GPU modules to GPU binaries ----------=// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements the `GpuModuleToBinaryPass` pass, transforming GPU |
10 | // modules into GPU binaries. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Dialect/GPU/Transforms/Passes.h" |
15 | |
16 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
17 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
18 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
19 | #include "mlir/Dialect/LLVMIR/NVVMDialect.h" |
20 | #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" |
21 | #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" |
22 | #include "mlir/IR/BuiltinOps.h" |
23 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
24 | |
25 | #include "llvm/ADT/STLExtras.h" |
26 | #include "llvm/ADT/StringSwitch.h" |
27 | |
28 | using namespace mlir; |
29 | using namespace mlir::gpu; |
30 | |
31 | namespace mlir { |
32 | #define GEN_PASS_DEF_GPUMODULETOBINARYPASS |
33 | #include "mlir/Dialect/GPU/Transforms/Passes.h.inc" |
34 | } // namespace mlir |
35 | |
36 | namespace { |
37 | class GpuModuleToBinaryPass |
38 | : public impl::GpuModuleToBinaryPassBase<GpuModuleToBinaryPass> { |
39 | public: |
40 | using Base::Base; |
41 | void runOnOperation() final; |
42 | }; |
43 | } // namespace |
44 | |
45 | void GpuModuleToBinaryPass::runOnOperation() { |
46 | RewritePatternSet patterns(&getContext()); |
47 | auto targetFormat = |
48 | llvm::StringSwitch<std::optional<CompilationTarget>>(compilationTarget) |
49 | .Cases("offloading" , "llvm" , CompilationTarget::Offload) |
50 | .Cases("assembly" , "isa" , CompilationTarget::Assembly) |
51 | .Cases("binary" , "bin" , CompilationTarget::Binary) |
52 | .Cases("fatbinary" , "fatbin" , CompilationTarget::Fatbin) |
53 | .Default(std::nullopt); |
54 | if (!targetFormat) |
55 | getOperation()->emitError() << "Invalid format specified." ; |
56 | |
57 | // Lazy symbol table builder callback. |
58 | std::optional<SymbolTable> parentTable; |
59 | auto lazyTableBuilder = [&]() -> SymbolTable * { |
60 | // Build the table if it has not been built. |
61 | if (!parentTable) { |
62 | Operation *table = SymbolTable::getNearestSymbolTable(from: getOperation()); |
63 | // It's up to the target attribute to determine if failing to find a |
64 | // symbol table is an error. |
65 | if (!table) |
66 | return nullptr; |
67 | parentTable = SymbolTable(table); |
68 | } |
69 | return &parentTable.value(); |
70 | }; |
71 | SmallVector<Attribute> librariesToLink; |
72 | for (const std::string &path : linkFiles) |
73 | librariesToLink.push_back(StringAttr::get(&getContext(), path)); |
74 | TargetOptions targetOptions(toolkitPath, librariesToLink, cmdOptions, |
75 | elfSection, *targetFormat, lazyTableBuilder); |
76 | if (failed(transformGpuModulesToBinaries( |
77 | getOperation(), OffloadingLLVMTranslationAttrInterface(nullptr), |
78 | targetOptions))) |
79 | return signalPassFailure(); |
80 | } |
81 | |
82 | namespace { |
83 | LogicalResult moduleSerializer(GPUModuleOp op, |
84 | OffloadingLLVMTranslationAttrInterface handler, |
85 | const TargetOptions &targetOptions) { |
86 | OpBuilder builder(op->getContext()); |
87 | SmallVector<Attribute> objects; |
88 | // Fail if there are no target attributes |
89 | if (!op.getTargetsAttr()) |
90 | return op.emitError("the module has no target attributes" ); |
91 | // Serialize all targets. |
92 | for (auto targetAttr : op.getTargetsAttr()) { |
93 | assert(targetAttr && "Target attribute cannot be null." ); |
94 | auto target = dyn_cast<gpu::TargetAttrInterface>(targetAttr); |
95 | assert(target && |
96 | "Target attribute doesn't implements `TargetAttrInterface`." ); |
97 | std::optional<SmallVector<char, 0>> serializedModule = |
98 | target.serializeToObject(op, targetOptions); |
99 | if (!serializedModule) { |
100 | op.emitError("An error happened while serializing the module." ); |
101 | return failure(); |
102 | } |
103 | |
104 | Attribute object = |
105 | target.createObject(op, *serializedModule, targetOptions); |
106 | if (!object) { |
107 | op.emitError("An error happened while creating the object." ); |
108 | return failure(); |
109 | } |
110 | objects.push_back(object); |
111 | } |
112 | if (auto moduleHandler = |
113 | dyn_cast_or_null<OffloadingLLVMTranslationAttrInterface>( |
114 | op.getOffloadingHandlerAttr()); |
115 | !handler && moduleHandler) |
116 | handler = moduleHandler; |
117 | builder.setInsertionPointAfter(op); |
118 | builder.create<gpu::BinaryOp>(op.getLoc(), op.getName(), handler, |
119 | builder.getArrayAttr(objects)); |
120 | op->erase(); |
121 | return success(); |
122 | } |
123 | } // namespace |
124 | |
125 | LogicalResult mlir::gpu::transformGpuModulesToBinaries( |
126 | Operation *op, OffloadingLLVMTranslationAttrInterface handler, |
127 | const gpu::TargetOptions &targetOptions) { |
128 | for (Region ®ion : op->getRegions()) |
129 | for (Block &block : region.getBlocks()) |
130 | for (auto module : |
131 | llvm::make_early_inc_range(block.getOps<GPUModuleOp>())) |
132 | if (failed(moduleSerializer(module, handler, targetOptions))) |
133 | return failure(); |
134 | return success(); |
135 | } |
136 | |