1 | //===- TranslateRegistration.cpp - hooks to mlir-translate ----------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
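// This file registers the SPIR-V translations with mlir-translate: SPIR-V
// binary to MLIR SPIR-V ModuleOp (deserialization), SPIR-V ModuleOp to SPIR-V
// binary (serialization), and test-only binary round-trip translations.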
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Target/SPIRV/Deserialization.h"
#include "mlir/Target/SPIRV/Serialization.h"
#include "mlir/Tools/mlir-translate/Translation.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SMLoc.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h"

#include <cstdint>
#include <cstring>
30 | |
31 | using namespace mlir; |
32 | |
33 | //===----------------------------------------------------------------------===// |
34 | // Deserialization registration |
35 | //===----------------------------------------------------------------------===// |
36 | |
37 | // Deserializes the SPIR-V binary module stored in the file named as |
38 | // `inputFilename` and returns a module containing the SPIR-V module. |
39 | static OwningOpRef<Operation *> |
40 | deserializeModule(const llvm::MemoryBuffer *input, MLIRContext *context) { |
41 | context->loadDialect<spirv::SPIRVDialect>(); |
42 | |
43 | // Make sure the input stream can be treated as a stream of SPIR-V words |
44 | auto *start = input->getBufferStart(); |
45 | auto size = input->getBufferSize(); |
46 | if (size % sizeof(uint32_t) != 0) { |
47 | emitError(UnknownLoc::get(context)) |
48 | << "SPIR-V binary module must contain integral number of 32-bit words" ; |
49 | return {}; |
50 | } |
51 | |
52 | auto binary = llvm::ArrayRef(reinterpret_cast<const uint32_t *>(start), |
53 | size / sizeof(uint32_t)); |
54 | return spirv::deserialize(binary, context); |
55 | } |
56 | |
57 | namespace mlir { |
58 | void registerFromSPIRVTranslation() { |
59 | TranslateToMLIRRegistration fromBinary( |
60 | "deserialize-spirv" , "deserializes the SPIR-V module" , |
61 | [](llvm::SourceMgr &sourceMgr, MLIRContext *context) { |
62 | assert(sourceMgr.getNumBuffers() == 1 && "expected one buffer" ); |
63 | return deserializeModule( |
64 | input: sourceMgr.getMemoryBuffer(i: sourceMgr.getMainFileID()), context); |
65 | }); |
66 | } |
67 | } // namespace mlir |
68 | |
69 | //===----------------------------------------------------------------------===// |
70 | // Serialization registration |
71 | //===----------------------------------------------------------------------===// |
72 | |
73 | static LogicalResult serializeModule(spirv::ModuleOp module, |
74 | raw_ostream &output) { |
75 | SmallVector<uint32_t, 0> binary; |
76 | if (failed(spirv::serialize(module: module, binary))) |
77 | return failure(); |
78 | |
79 | output.write(Ptr: reinterpret_cast<char *>(binary.data()), |
80 | Size: binary.size() * sizeof(uint32_t)); |
81 | |
82 | return mlir::success(); |
83 | } |
84 | |
85 | namespace mlir { |
86 | void registerToSPIRVTranslation() { |
87 | TranslateFromMLIRRegistration toBinary( |
88 | "serialize-spirv" , "serialize SPIR-V dialect" , |
89 | [](spirv::ModuleOp module, raw_ostream &output) { |
90 | return serializeModule(module, output); |
91 | }, |
92 | [](DialectRegistry ®istry) { |
93 | registry.insert<spirv::SPIRVDialect>(); |
94 | }); |
95 | } |
96 | } // namespace mlir |
97 | |
98 | //===----------------------------------------------------------------------===// |
99 | // Round-trip registration |
100 | //===----------------------------------------------------------------------===// |
101 | |
102 | static LogicalResult roundTripModule(spirv::ModuleOp module, bool emitDebugInfo, |
103 | raw_ostream &output) { |
104 | SmallVector<uint32_t, 0> binary; |
105 | MLIRContext *context = module->getContext(); |
106 | |
107 | spirv::SerializationOptions options; |
108 | options.emitDebugInfo = emitDebugInfo; |
109 | if (failed(spirv::serialize(module: module, binary, options))) |
110 | return failure(); |
111 | |
112 | MLIRContext deserializationContext(context->getDialectRegistry()); |
113 | // TODO: we should only load the required dialects instead of all dialects. |
114 | deserializationContext.loadAllAvailableDialects(); |
115 | // Then deserialize to get back a SPIR-V module. |
116 | OwningOpRef<spirv::ModuleOp> spirvModule = |
117 | spirv::deserialize(binary, context: &deserializationContext); |
118 | if (!spirvModule) |
119 | return failure(); |
120 | spirvModule->print(output); |
121 | |
122 | return mlir::success(); |
123 | } |
124 | |
125 | namespace mlir { |
126 | void registerTestRoundtripSPIRV() { |
127 | TranslateFromMLIRRegistration roundtrip( |
128 | "test-spirv-roundtrip" , "test roundtrip in SPIR-V dialect" , |
129 | [](spirv::ModuleOp module, raw_ostream &output) { |
130 | return roundTripModule(module, /*emitDebugInfo=*/false, output); |
131 | }, |
132 | [](DialectRegistry ®istry) { |
133 | registry.insert<spirv::SPIRVDialect>(); |
134 | }); |
135 | } |
136 | |
137 | void registerTestRoundtripDebugSPIRV() { |
138 | TranslateFromMLIRRegistration roundtrip( |
139 | "test-spirv-roundtrip-debug" , "test roundtrip debug in SPIR-V" , |
140 | [](spirv::ModuleOp module, raw_ostream &output) { |
141 | return roundTripModule(module, /*emitDebugInfo=*/true, output); |
142 | }, |
143 | [](DialectRegistry ®istry) { |
144 | registry.insert<spirv::SPIRVDialect>(); |
145 | }); |
146 | } |
147 | } // namespace mlir |
148 | |