| 1 | //===- Target.cpp - MLIR SPIR-V target compilation --------------*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
// This file defines SPIR-V target-related functions, including registration
// calls for the `#spirv.target_env` compilation attribute.
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "mlir/Target/SPIRV/Target.h" |
| 15 | |
| 16 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
| 17 | #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" |
| 18 | #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" |
| 19 | #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
| 20 | #include "mlir/Target/SPIRV/Serialization.h" |
| 21 | |
| 22 | #include <cstdlib> |
| 23 | #include <cstring> |
| 24 | |
| 25 | using namespace mlir; |
| 26 | using namespace mlir::spirv; |
| 27 | |
| 28 | namespace { |
| 29 | // SPIR-V implementation of the gpu:TargetAttrInterface. |
| 30 | class SPIRVTargetAttrImpl |
| 31 | : public gpu::TargetAttrInterface::FallbackModel<SPIRVTargetAttrImpl> { |
| 32 | public: |
| 33 | std::optional<SmallVector<char, 0>> |
| 34 | serializeToObject(Attribute attribute, Operation *module, |
| 35 | const gpu::TargetOptions &options) const; |
| 36 | |
| 37 | Attribute createObject(Attribute attribute, Operation *module, |
| 38 | const SmallVector<char, 0> &object, |
| 39 | const gpu::TargetOptions &options) const; |
| 40 | }; |
| 41 | } // namespace |
| 42 | |
| 43 | // Register the SPIR-V dialect, the SPIR-V translation & the target interface. |
| 44 | void mlir::spirv::registerSPIRVTargetInterfaceExternalModels( |
| 45 | DialectRegistry ®istry) { |
| 46 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, spirv::SPIRVDialect *dialect) { |
| 47 | spirv::TargetEnvAttr::attachInterface<SPIRVTargetAttrImpl>(*ctx); |
| 48 | }); |
| 49 | } |
| 50 | |
| 51 | void mlir::spirv::registerSPIRVTargetInterfaceExternalModels( |
| 52 | MLIRContext &context) { |
| 53 | DialectRegistry registry; |
| 54 | registerSPIRVTargetInterfaceExternalModels(registry); |
| 55 | context.appendDialectRegistry(registry); |
| 56 | } |
| 57 | |
| 58 | // Reuse from existing serializer |
| 59 | std::optional<SmallVector<char, 0>> SPIRVTargetAttrImpl::serializeToObject( |
| 60 | Attribute attribute, Operation *module, |
| 61 | const gpu::TargetOptions &options) const { |
| 62 | if (!module) |
| 63 | return std::nullopt; |
| 64 | auto gpuMod = dyn_cast<gpu::GPUModuleOp>(module); |
| 65 | if (!gpuMod) { |
| 66 | module->emitError(message: "expected to be a gpu.module op" ); |
| 67 | return std::nullopt; |
| 68 | } |
| 69 | auto spvMods = gpuMod.getOps<spirv::ModuleOp>(); |
| 70 | if (spvMods.empty()) |
| 71 | return std::nullopt; |
| 72 | |
| 73 | auto spvMod = *spvMods.begin(); |
| 74 | llvm::SmallVector<uint32_t, 0> spvBinary; |
| 75 | |
| 76 | spvBinary.clear(); |
| 77 | // Serialize the spirv.module op to SPIR-V blob. |
| 78 | if (mlir::failed(Result: spirv::serialize(module: spvMod, binary&: spvBinary))) { |
| 79 | spvMod.emitError() << "failed to serialize SPIR-V module" ; |
| 80 | return std::nullopt; |
| 81 | } |
| 82 | |
| 83 | SmallVector<char, 0> spvData(spvBinary.size() * sizeof(uint32_t), 0); |
| 84 | std::memcpy(dest: spvData.data(), src: spvBinary.data(), n: spvData.size()); |
| 85 | |
| 86 | spvMod.erase(); |
| 87 | return spvData; |
| 88 | } |
| 89 | |
| 90 | // Prepare Attribute for gpu.binary with serialized kernel object |
| 91 | Attribute |
| 92 | SPIRVTargetAttrImpl::createObject(Attribute attribute, Operation *module, |
| 93 | const SmallVector<char, 0> &object, |
| 94 | const gpu::TargetOptions &options) const { |
| 95 | gpu::CompilationTarget format = options.getCompilationTarget(); |
| 96 | DictionaryAttr objectProps; |
| 97 | Builder builder(attribute.getContext()); |
| 98 | return builder.getAttr<gpu::ObjectAttr>( |
| 99 | attribute, format, |
| 100 | builder.getStringAttr(StringRef(object.data(), object.size())), |
| 101 | objectProps, /*kernels=*/nullptr); |
| 102 | } |
| 103 | |