1//===- TranslateRegistration.cpp - hooks to mlir-translate ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
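// This file registers the mlir-translate hooks for SPIR-V: deserialization of
// a SPIR-V binary module into an MLIR SPIR-V ModuleOp, serialization of a
// SPIR-V ModuleOp back into a binary, and the serialize/deserialize round-trip
// test translations.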
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
15#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
16#include "mlir/IR/Builders.h"
17#include "mlir/IR/Verifier.h"
18#include "mlir/Parser/Parser.h"
19#include "mlir/Support/FileUtilities.h"
20#include "mlir/Target/SPIRV/Deserialization.h"
21#include "mlir/Target/SPIRV/Serialization.h"
22#include "mlir/Tools/mlir-translate/Translation.h"
23#include "llvm/ADT/StringRef.h"
24#include "llvm/Support/MemoryBuffer.h"
25#include "llvm/Support/SourceMgr.h"
26
27using namespace mlir;
28
29//===----------------------------------------------------------------------===//
30// Deserialization registration
31//===----------------------------------------------------------------------===//
32
33// Deserializes the SPIR-V binary module stored in the file named as
34// `inputFilename` and returns a module containing the SPIR-V module.
35static OwningOpRef<Operation *>
36deserializeModule(const llvm::MemoryBuffer *input, MLIRContext *context,
37 const spirv::DeserializationOptions &options) {
38 context->loadDialect<spirv::SPIRVDialect>();
39
40 // Make sure the input stream can be treated as a stream of SPIR-V words
41 auto *start = input->getBufferStart();
42 auto size = input->getBufferSize();
43 if (size % sizeof(uint32_t) != 0) {
44 emitError(loc: UnknownLoc::get(context))
45 << "SPIR-V binary module must contain integral number of 32-bit words";
46 return {};
47 }
48
49 auto binary = llvm::ArrayRef(reinterpret_cast<const uint32_t *>(start),
50 size / sizeof(uint32_t));
51 return spirv::deserialize(binary, context, options);
52}
53
54namespace mlir {
55void registerFromSPIRVTranslation() {
56 static llvm::cl::opt<bool> enableControlFlowStructurization(
57 "spirv-structurize-control-flow",
58 llvm::cl::desc(
59 "Enable control flow structurization into `spirv.mlir.selection` and "
60 "`spirv.mlir.loop`. This may need to be disabled to support "
61 "deserialization of early exits (see #138688)"),
62 llvm::cl::init(Val: true));
63
64 TranslateToMLIRRegistration fromBinary(
65 "deserialize-spirv", "deserializes the SPIR-V module",
66 [](llvm::SourceMgr &sourceMgr, MLIRContext *context) {
67 assert(sourceMgr.getNumBuffers() == 1 && "expected one buffer");
68 return deserializeModule(
69 input: sourceMgr.getMemoryBuffer(i: sourceMgr.getMainFileID()), context,
70 options: {.enableControlFlowStructurization: enableControlFlowStructurization});
71 });
72}
73} // namespace mlir
74
75//===----------------------------------------------------------------------===//
76// Serialization registration
77//===----------------------------------------------------------------------===//
78
79static LogicalResult serializeModule(spirv::ModuleOp module,
80 raw_ostream &output) {
81 SmallVector<uint32_t, 0> binary;
82 if (failed(Result: spirv::serialize(module, binary)))
83 return failure();
84
85 output.write(Ptr: reinterpret_cast<char *>(binary.data()),
86 Size: binary.size() * sizeof(uint32_t));
87
88 return mlir::success();
89}
90
91namespace mlir {
92void registerToSPIRVTranslation() {
93 TranslateFromMLIRRegistration toBinary(
94 "serialize-spirv", "serialize SPIR-V dialect",
95 [](spirv::ModuleOp module, raw_ostream &output) {
96 return serializeModule(module, output);
97 },
98 [](DialectRegistry &registry) {
99 registry.insert<spirv::SPIRVDialect>();
100 });
101}
102} // namespace mlir
103
104//===----------------------------------------------------------------------===//
105// Round-trip registration
106//===----------------------------------------------------------------------===//
107
108static LogicalResult roundTripModule(spirv::ModuleOp module, bool emitDebugInfo,
109 raw_ostream &output) {
110 SmallVector<uint32_t, 0> binary;
111 MLIRContext *context = module->getContext();
112
113 spirv::SerializationOptions options;
114 options.emitDebugInfo = emitDebugInfo;
115 if (failed(Result: spirv::serialize(module, binary, options)))
116 return failure();
117
118 MLIRContext deserializationContext(context->getDialectRegistry());
119 // TODO: we should only load the required dialects instead of all dialects.
120 deserializationContext.loadAllAvailableDialects();
121 // Then deserialize to get back a SPIR-V module.
122 OwningOpRef<spirv::ModuleOp> spirvModule =
123 spirv::deserialize(binary, context: &deserializationContext);
124 if (!spirvModule)
125 return failure();
126 spirvModule->print(os&: output);
127
128 return mlir::success();
129}
130
131namespace mlir {
132void registerTestRoundtripSPIRV() {
133 TranslateFromMLIRRegistration roundtrip(
134 "test-spirv-roundtrip", "test roundtrip in SPIR-V dialect",
135 [](spirv::ModuleOp module, raw_ostream &output) {
136 return roundTripModule(module, /*emitDebugInfo=*/false, output);
137 },
138 [](DialectRegistry &registry) {
139 registry.insert<spirv::SPIRVDialect>();
140 });
141}
142
143void registerTestRoundtripDebugSPIRV() {
144 TranslateFromMLIRRegistration roundtrip(
145 "test-spirv-roundtrip-debug", "test roundtrip debug in SPIR-V",
146 [](spirv::ModuleOp module, raw_ostream &output) {
147 return roundTripModule(module, /*emitDebugInfo=*/true, output);
148 },
149 [](DialectRegistry &registry) {
150 registry.insert<spirv::SPIRVDialect>();
151 });
152}
153} // namespace mlir
154

source code of mlir/lib/Target/SPIRV/TranslateRegistration.cpp