| 1 | //===- TranslateRegistration.cpp - hooks to mlir-translate ----------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
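// This file registers the SPIR-V translations with mlir-translate:
// deserialization of a SPIR-V binary module into an MLIR SPIR-V ModuleOp,
// serialization of a SPIR-V ModuleOp into a binary, and test round-trips.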
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Target/SPIRV/Deserialization.h"
#include "mlir/Target/SPIRV/Serialization.h"
#include "mlir/Tools/mlir-translate/Translation.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SMLoc.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h"

#include <cstdint>
#include <cstring>
| 30 | |
| 31 | using namespace mlir; |
| 32 | |
| 33 | //===----------------------------------------------------------------------===// |
| 34 | // Deserialization registration |
| 35 | //===----------------------------------------------------------------------===// |
| 36 | |
| 37 | // Deserializes the SPIR-V binary module stored in the file named as |
| 38 | // `inputFilename` and returns a module containing the SPIR-V module. |
| 39 | static OwningOpRef<Operation *> |
| 40 | deserializeModule(const llvm::MemoryBuffer *input, MLIRContext *context, |
| 41 | const spirv::DeserializationOptions &options) { |
| 42 | context->loadDialect<spirv::SPIRVDialect>(); |
| 43 | |
| 44 | // Make sure the input stream can be treated as a stream of SPIR-V words |
| 45 | auto *start = input->getBufferStart(); |
| 46 | auto size = input->getBufferSize(); |
| 47 | if (size % sizeof(uint32_t) != 0) { |
| 48 | emitError(UnknownLoc::get(context)) |
| 49 | << "SPIR-V binary module must contain integral number of 32-bit words" ; |
| 50 | return {}; |
| 51 | } |
| 52 | |
| 53 | auto binary = llvm::ArrayRef(reinterpret_cast<const uint32_t *>(start), |
| 54 | size / sizeof(uint32_t)); |
| 55 | return spirv::deserialize(binary, context, options); |
| 56 | } |
| 57 | |
| 58 | namespace mlir { |
| 59 | void registerFromSPIRVTranslation() { |
| 60 | static llvm::cl::opt<bool> enableControlFlowStructurization( |
| 61 | "spirv-structurize-control-flow" , |
| 62 | llvm::cl::desc( |
| 63 | "Enable control flow structurization into `spirv.mlir.selection` and " |
| 64 | "`spirv.mlir.loop`. This may need to be disabled to support " |
| 65 | "deserialization of early exits (see #138688)" ), |
| 66 | llvm::cl::init(Val: true)); |
| 67 | |
| 68 | TranslateToMLIRRegistration fromBinary( |
| 69 | "deserialize-spirv" , "deserializes the SPIR-V module" , |
| 70 | [](llvm::SourceMgr &sourceMgr, MLIRContext *context) { |
| 71 | assert(sourceMgr.getNumBuffers() == 1 && "expected one buffer" ); |
| 72 | return deserializeModule( |
| 73 | input: sourceMgr.getMemoryBuffer(i: sourceMgr.getMainFileID()), context, |
| 74 | options: {.enableControlFlowStructurization: enableControlFlowStructurization}); |
| 75 | }); |
| 76 | } |
| 77 | } // namespace mlir |
| 78 | |
| 79 | //===----------------------------------------------------------------------===// |
| 80 | // Serialization registration |
| 81 | //===----------------------------------------------------------------------===// |
| 82 | |
| 83 | static LogicalResult serializeModule(spirv::ModuleOp module, |
| 84 | raw_ostream &output) { |
| 85 | SmallVector<uint32_t, 0> binary; |
| 86 | if (failed(spirv::serialize(module: module, binary))) |
| 87 | return failure(); |
| 88 | |
| 89 | output.write(Ptr: reinterpret_cast<char *>(binary.data()), |
| 90 | Size: binary.size() * sizeof(uint32_t)); |
| 91 | |
| 92 | return mlir::success(); |
| 93 | } |
| 94 | |
| 95 | namespace mlir { |
| 96 | void registerToSPIRVTranslation() { |
| 97 | TranslateFromMLIRRegistration toBinary( |
| 98 | "serialize-spirv" , "serialize SPIR-V dialect" , |
| 99 | [](spirv::ModuleOp module, raw_ostream &output) { |
| 100 | return serializeModule(module, output); |
| 101 | }, |
| 102 | [](DialectRegistry ®istry) { |
| 103 | registry.insert<spirv::SPIRVDialect>(); |
| 104 | }); |
| 105 | } |
| 106 | } // namespace mlir |
| 107 | |
| 108 | //===----------------------------------------------------------------------===// |
| 109 | // Round-trip registration |
| 110 | //===----------------------------------------------------------------------===// |
| 111 | |
| 112 | static LogicalResult roundTripModule(spirv::ModuleOp module, bool emitDebugInfo, |
| 113 | raw_ostream &output) { |
| 114 | SmallVector<uint32_t, 0> binary; |
| 115 | MLIRContext *context = module->getContext(); |
| 116 | |
| 117 | spirv::SerializationOptions options; |
| 118 | options.emitDebugInfo = emitDebugInfo; |
| 119 | if (failed(spirv::serialize(module: module, binary, options))) |
| 120 | return failure(); |
| 121 | |
| 122 | MLIRContext deserializationContext(context->getDialectRegistry()); |
| 123 | // TODO: we should only load the required dialects instead of all dialects. |
| 124 | deserializationContext.loadAllAvailableDialects(); |
| 125 | // Then deserialize to get back a SPIR-V module. |
| 126 | OwningOpRef<spirv::ModuleOp> spirvModule = |
| 127 | spirv::deserialize(binary, context: &deserializationContext); |
| 128 | if (!spirvModule) |
| 129 | return failure(); |
| 130 | spirvModule->print(output); |
| 131 | |
| 132 | return mlir::success(); |
| 133 | } |
| 134 | |
| 135 | namespace mlir { |
| 136 | void registerTestRoundtripSPIRV() { |
| 137 | TranslateFromMLIRRegistration roundtrip( |
| 138 | "test-spirv-roundtrip" , "test roundtrip in SPIR-V dialect" , |
| 139 | [](spirv::ModuleOp module, raw_ostream &output) { |
| 140 | return roundTripModule(module, /*emitDebugInfo=*/false, output); |
| 141 | }, |
| 142 | [](DialectRegistry ®istry) { |
| 143 | registry.insert<spirv::SPIRVDialect>(); |
| 144 | }); |
| 145 | } |
| 146 | |
| 147 | void registerTestRoundtripDebugSPIRV() { |
| 148 | TranslateFromMLIRRegistration roundtrip( |
| 149 | "test-spirv-roundtrip-debug" , "test roundtrip debug in SPIR-V" , |
| 150 | [](spirv::ModuleOp module, raw_ostream &output) { |
| 151 | return roundTripModule(module, /*emitDebugInfo=*/true, output); |
| 152 | }, |
| 153 | [](DialectRegistry ®istry) { |
| 154 | registry.insert<spirv::SPIRVDialect>(); |
| 155 | }); |
| 156 | } |
| 157 | } // namespace mlir |
| 158 | |