1//===- TranslateRegistration.cpp - hooks to mlir-translate ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
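// This file registers the SPIR-V translations with mlir-translate:
// deserialization of a SPIR-V binary module into an MLIR SPIR-V ModuleOp,
// serialization of a SPIR-V ModuleOp into a binary module, and round-trip
// test translations that serialize and then deserialize a module.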
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
15#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
16#include "mlir/IR/Builders.h"
17#include "mlir/IR/BuiltinOps.h"
18#include "mlir/IR/Dialect.h"
19#include "mlir/IR/Verifier.h"
20#include "mlir/Parser/Parser.h"
21#include "mlir/Support/FileUtilities.h"
22#include "mlir/Target/SPIRV/Deserialization.h"
23#include "mlir/Target/SPIRV/Serialization.h"
24#include "mlir/Tools/mlir-translate/Translation.h"
25#include "llvm/ADT/StringRef.h"
26#include "llvm/Support/MemoryBuffer.h"
27#include "llvm/Support/SMLoc.h"
28#include "llvm/Support/SourceMgr.h"
29#include "llvm/Support/ToolOutputFile.h"
30
31using namespace mlir;
32
33//===----------------------------------------------------------------------===//
34// Deserialization registration
35//===----------------------------------------------------------------------===//
36
37// Deserializes the SPIR-V binary module stored in the file named as
38// `inputFilename` and returns a module containing the SPIR-V module.
39static OwningOpRef<Operation *>
40deserializeModule(const llvm::MemoryBuffer *input, MLIRContext *context,
41 const spirv::DeserializationOptions &options) {
42 context->loadDialect<spirv::SPIRVDialect>();
43
44 // Make sure the input stream can be treated as a stream of SPIR-V words
45 auto *start = input->getBufferStart();
46 auto size = input->getBufferSize();
47 if (size % sizeof(uint32_t) != 0) {
48 emitError(UnknownLoc::get(context))
49 << "SPIR-V binary module must contain integral number of 32-bit words";
50 return {};
51 }
52
53 auto binary = llvm::ArrayRef(reinterpret_cast<const uint32_t *>(start),
54 size / sizeof(uint32_t));
55 return spirv::deserialize(binary, context, options);
56}
57
58namespace mlir {
59void registerFromSPIRVTranslation() {
60 static llvm::cl::opt<bool> enableControlFlowStructurization(
61 "spirv-structurize-control-flow",
62 llvm::cl::desc(
63 "Enable control flow structurization into `spirv.mlir.selection` and "
64 "`spirv.mlir.loop`. This may need to be disabled to support "
65 "deserialization of early exits (see #138688)"),
66 llvm::cl::init(Val: true));
67
68 TranslateToMLIRRegistration fromBinary(
69 "deserialize-spirv", "deserializes the SPIR-V module",
70 [](llvm::SourceMgr &sourceMgr, MLIRContext *context) {
71 assert(sourceMgr.getNumBuffers() == 1 && "expected one buffer");
72 return deserializeModule(
73 input: sourceMgr.getMemoryBuffer(i: sourceMgr.getMainFileID()), context,
74 options: {.enableControlFlowStructurization: enableControlFlowStructurization});
75 });
76}
77} // namespace mlir
78
79//===----------------------------------------------------------------------===//
80// Serialization registration
81//===----------------------------------------------------------------------===//
82
83static LogicalResult serializeModule(spirv::ModuleOp module,
84 raw_ostream &output) {
85 SmallVector<uint32_t, 0> binary;
86 if (failed(spirv::serialize(module: module, binary)))
87 return failure();
88
89 output.write(Ptr: reinterpret_cast<char *>(binary.data()),
90 Size: binary.size() * sizeof(uint32_t));
91
92 return mlir::success();
93}
94
95namespace mlir {
96void registerToSPIRVTranslation() {
97 TranslateFromMLIRRegistration toBinary(
98 "serialize-spirv", "serialize SPIR-V dialect",
99 [](spirv::ModuleOp module, raw_ostream &output) {
100 return serializeModule(module, output);
101 },
102 [](DialectRegistry &registry) {
103 registry.insert<spirv::SPIRVDialect>();
104 });
105}
106} // namespace mlir
107
108//===----------------------------------------------------------------------===//
109// Round-trip registration
110//===----------------------------------------------------------------------===//
111
112static LogicalResult roundTripModule(spirv::ModuleOp module, bool emitDebugInfo,
113 raw_ostream &output) {
114 SmallVector<uint32_t, 0> binary;
115 MLIRContext *context = module->getContext();
116
117 spirv::SerializationOptions options;
118 options.emitDebugInfo = emitDebugInfo;
119 if (failed(spirv::serialize(module: module, binary, options)))
120 return failure();
121
122 MLIRContext deserializationContext(context->getDialectRegistry());
123 // TODO: we should only load the required dialects instead of all dialects.
124 deserializationContext.loadAllAvailableDialects();
125 // Then deserialize to get back a SPIR-V module.
126 OwningOpRef<spirv::ModuleOp> spirvModule =
127 spirv::deserialize(binary, context: &deserializationContext);
128 if (!spirvModule)
129 return failure();
130 spirvModule->print(output);
131
132 return mlir::success();
133}
134
135namespace mlir {
136void registerTestRoundtripSPIRV() {
137 TranslateFromMLIRRegistration roundtrip(
138 "test-spirv-roundtrip", "test roundtrip in SPIR-V dialect",
139 [](spirv::ModuleOp module, raw_ostream &output) {
140 return roundTripModule(module, /*emitDebugInfo=*/false, output);
141 },
142 [](DialectRegistry &registry) {
143 registry.insert<spirv::SPIRVDialect>();
144 });
145}
146
147void registerTestRoundtripDebugSPIRV() {
148 TranslateFromMLIRRegistration roundtrip(
149 "test-spirv-roundtrip-debug", "test roundtrip debug in SPIR-V",
150 [](spirv::ModuleOp module, raw_ostream &output) {
151 return roundTripModule(module, /*emitDebugInfo=*/true, output);
152 },
153 [](DialectRegistry &registry) {
154 registry.insert<spirv::SPIRVDialect>();
155 });
156}
157} // namespace mlir
158

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Target/SPIRV/TranslateRegistration.cpp