//===- ModuleToBinary.cpp - Transforms GPU modules to GPU binaries ----------=//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the `GpuModuleToBinaryPass` pass, transforming GPU
// modules into GPU binaries.
//
//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/GPU/Transforms/Passes.h"
15
16#include "mlir/Config/mlir-config.h"
17#include "mlir/Dialect/Func/IR/FuncOps.h"
18#include "mlir/Dialect/GPU/IR/GPUDialect.h"
19#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
21#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
22#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
23#include "mlir/IR/BuiltinOps.h"
24#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
25
26#include "llvm/ADT/STLExtras.h"
27#include "llvm/ADT/StringSwitch.h"
28
29using namespace mlir;
30using namespace mlir::gpu;
31
32namespace mlir {
33#define GEN_PASS_DEF_GPUMODULETOBINARYPASS
34#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
35} // namespace mlir
36
37namespace {
38class GpuModuleToBinaryPass
39 : public impl::GpuModuleToBinaryPassBase<GpuModuleToBinaryPass> {
40public:
41 using Base::Base;
42 void getDependentDialects(DialectRegistry &registry) const override;
43 void runOnOperation() final;
44};
45} // namespace
46
47void GpuModuleToBinaryPass::getDependentDialects(
48 DialectRegistry &registry) const {
49 // Register all GPU related translations.
50 registry.insert<gpu::GPUDialect>();
51 registry.insert<LLVM::LLVMDialect>();
52#if MLIR_ENABLE_CUDA_CONVERSIONS
53 registry.insert<NVVM::NVVMDialect>();
54#endif
55#if MLIR_ENABLE_ROCM_CONVERSIONS
56 registry.insert<ROCDL::ROCDLDialect>();
57#endif
58 registry.insert<spirv::SPIRVDialect>();
59}
60
61void GpuModuleToBinaryPass::runOnOperation() {
62 RewritePatternSet patterns(&getContext());
63 auto targetFormat =
64 llvm::StringSwitch<std::optional<CompilationTarget>>(compilationTarget)
65 .Cases("offloading", "llvm", CompilationTarget::Offload)
66 .Cases("assembly", "isa", CompilationTarget::Assembly)
67 .Cases("binary", "bin", CompilationTarget::Binary)
68 .Cases("fatbinary", "fatbin", CompilationTarget::Fatbin)
69 .Default(std::nullopt);
70 if (!targetFormat)
71 getOperation()->emitError() << "Invalid format specified.";
72
73 // Lazy symbol table builder callback.
74 std::optional<SymbolTable> parentTable;
75 auto lazyTableBuilder = [&]() -> SymbolTable * {
76 // Build the table if it has not been built.
77 if (!parentTable) {
78 Operation *table = SymbolTable::getNearestSymbolTable(from: getOperation());
79 // It's up to the target attribute to determine if failing to find a
80 // symbol table is an error.
81 if (!table)
82 return nullptr;
83 parentTable = SymbolTable(table);
84 }
85 return &parentTable.value();
86 };
87
88 TargetOptions targetOptions(toolkitPath, linkFiles, cmdOptions, *targetFormat,
89 lazyTableBuilder);
90 if (failed(transformGpuModulesToBinaries(
91 getOperation(),
92 offloadingHandler ? dyn_cast<OffloadingLLVMTranslationAttrInterface>(
93 offloadingHandler.getValue())
94 : OffloadingLLVMTranslationAttrInterface(nullptr),
95 targetOptions)))
96 return signalPassFailure();
97}
98
99namespace {
100LogicalResult moduleSerializer(GPUModuleOp op,
101 OffloadingLLVMTranslationAttrInterface handler,
102 const TargetOptions &targetOptions) {
103 OpBuilder builder(op->getContext());
104 SmallVector<Attribute> objects;
105 // Fail if there are no target attributes
106 if (!op.getTargetsAttr())
107 return op.emitError("the module has no target attributes");
108 // Serialize all targets.
109 for (auto targetAttr : op.getTargetsAttr()) {
110 assert(targetAttr && "Target attribute cannot be null.");
111 auto target = dyn_cast<gpu::TargetAttrInterface>(targetAttr);
112 assert(target &&
113 "Target attribute doesn't implements `TargetAttrInterface`.");
114 std::optional<SmallVector<char, 0>> serializedModule =
115 target.serializeToObject(op, targetOptions);
116 if (!serializedModule) {
117 op.emitError("An error happened while serializing the module.");
118 return failure();
119 }
120
121 Attribute object = target.createObject(*serializedModule, targetOptions);
122 if (!object) {
123 op.emitError("An error happened while creating the object.");
124 return failure();
125 }
126 objects.push_back(object);
127 }
128 if (auto moduleHandler =
129 dyn_cast_or_null<OffloadingLLVMTranslationAttrInterface>(
130 op.getOffloadingHandlerAttr());
131 !handler && moduleHandler)
132 handler = moduleHandler;
133 builder.setInsertionPointAfter(op);
134 builder.create<gpu::BinaryOp>(op.getLoc(), op.getName(), handler,
135 builder.getArrayAttr(objects));
136 op->erase();
137 return success();
138}
139} // namespace
140
141LogicalResult mlir::gpu::transformGpuModulesToBinaries(
142 Operation *op, OffloadingLLVMTranslationAttrInterface handler,
143 const gpu::TargetOptions &targetOptions) {
144 for (Region &region : op->getRegions())
145 for (Block &block : region.getBlocks())
146 for (auto module :
147 llvm::make_early_inc_range(block.getOps<GPUModuleOp>()))
148 if (failed(moduleSerializer(module, handler, targetOptions)))
149 return failure();
150 return success();
151}
152

// Source: mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp