1//===- ModuleToBinary.cpp - Transforms GPU modules to GPU binaries ----------=//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the `GpuModuleToBinaryPass` pass, transforming GPU
10// modules into GPU binaries.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/GPU/Transforms/Passes.h"
15
16#include "mlir/Dialect/GPU/IR/GPUDialect.h"
17
18#include "llvm/ADT/STLExtras.h"
19#include "llvm/ADT/StringSwitch.h"
20
21using namespace mlir;
22using namespace mlir::gpu;
23
24namespace mlir {
25#define GEN_PASS_DEF_GPUMODULETOBINARYPASS
26#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
27} // namespace mlir
28
29namespace {
30class GpuModuleToBinaryPass
31 : public impl::GpuModuleToBinaryPassBase<GpuModuleToBinaryPass> {
32public:
33 using Base::Base;
34 void runOnOperation() final;
35};
36} // namespace
37
38void GpuModuleToBinaryPass::runOnOperation() {
39 RewritePatternSet patterns(&getContext());
40 auto targetFormat =
41 llvm::StringSwitch<std::optional<CompilationTarget>>(compilationTarget)
42 .Cases(S0: "offloading", S1: "llvm", Value: CompilationTarget::Offload)
43 .Cases(S0: "assembly", S1: "isa", Value: CompilationTarget::Assembly)
44 .Cases(S0: "binary", S1: "bin", Value: CompilationTarget::Binary)
45 .Cases(S0: "fatbinary", S1: "fatbin", Value: CompilationTarget::Fatbin)
46 .Default(Value: std::nullopt);
47 if (!targetFormat)
48 getOperation()->emitError() << "Invalid format specified.";
49
50 // Lazy symbol table builder callback.
51 std::optional<SymbolTable> parentTable;
52 auto lazyTableBuilder = [&]() -> SymbolTable * {
53 // Build the table if it has not been built.
54 if (!parentTable) {
55 Operation *table = SymbolTable::getNearestSymbolTable(from: getOperation());
56 // It's up to the target attribute to determine if failing to find a
57 // symbol table is an error.
58 if (!table)
59 return nullptr;
60 parentTable = SymbolTable(table);
61 }
62 return &parentTable.value();
63 };
64 SmallVector<Attribute> librariesToLink;
65 for (const std::string &path : linkFiles)
66 librariesToLink.push_back(Elt: StringAttr::get(context: &getContext(), bytes: path));
67 TargetOptions targetOptions(toolkitPath, librariesToLink, cmdOptions,
68 elfSection, *targetFormat, lazyTableBuilder);
69 if (failed(Result: transformGpuModulesToBinaries(
70 op: getOperation(), handler: OffloadingLLVMTranslationAttrInterface(nullptr),
71 options: targetOptions)))
72 return signalPassFailure();
73}
74
75namespace {
76LogicalResult moduleSerializer(GPUModuleOp op,
77 OffloadingLLVMTranslationAttrInterface handler,
78 const TargetOptions &targetOptions) {
79 OpBuilder builder(op->getContext());
80 SmallVector<Attribute> objects;
81 // Fail if there are no target attributes
82 if (!op.getTargetsAttr())
83 return op.emitError(message: "the module has no target attributes");
84 // Serialize all targets.
85 for (auto targetAttr : op.getTargetsAttr()) {
86 assert(targetAttr && "Target attribute cannot be null.");
87 auto target = dyn_cast<gpu::TargetAttrInterface>(Val&: targetAttr);
88 assert(target &&
89 "Target attribute doesn't implements `TargetAttrInterface`.");
90 std::optional<SmallVector<char, 0>> serializedModule =
91 target.serializeToObject(module: op, options: targetOptions);
92 if (!serializedModule) {
93 op.emitError(message: "An error happened while serializing the module.");
94 return failure();
95 }
96
97 Attribute object =
98 target.createObject(module: op, object: *serializedModule, options: targetOptions);
99 if (!object) {
100 op.emitError(message: "An error happened while creating the object.");
101 return failure();
102 }
103 objects.push_back(Elt: object);
104 }
105 if (auto moduleHandler =
106 dyn_cast_or_null<OffloadingLLVMTranslationAttrInterface>(
107 Val: op.getOffloadingHandlerAttr());
108 !handler && moduleHandler)
109 handler = moduleHandler;
110 builder.setInsertionPointAfter(op);
111 builder.create<gpu::BinaryOp>(location: op.getLoc(), args: op.getName(), args&: handler,
112 args: builder.getArrayAttr(value: objects));
113 op->erase();
114 return success();
115}
116} // namespace
117
118LogicalResult mlir::gpu::transformGpuModulesToBinaries(
119 Operation *op, OffloadingLLVMTranslationAttrInterface handler,
120 const gpu::TargetOptions &targetOptions) {
121 for (Region &region : op->getRegions())
122 for (Block &block : region.getBlocks())
123 for (auto module :
124 llvm::make_early_inc_range(Range: block.getOps<GPUModuleOp>()))
125 if (failed(Result: moduleSerializer(op: module, handler, targetOptions)))
126 return failure();
127 return success();
128}
129

source code of mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp