| 1 | //===- SPIRVAttachTarget.cpp - Attach an SPIR-V target --------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements the `GPUSPIRVAttachTarget` pass, attaching |
| 10 | // `#spirv.target_env` attributes to GPU modules. |
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "mlir/Dialect/GPU/Transforms/Passes.h" |
| 15 | |
| 16 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
| 17 | #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" |
| 18 | #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" |
| 19 | #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" |
| 20 | #include "mlir/IR/Builders.h" |
| 21 | #include "mlir/Pass/Pass.h" |
| 22 | #include "mlir/Target/SPIRV/Target.h" |
| 23 | #include "llvm/Support/Regex.h" |
| 24 | |
| 25 | namespace mlir { |
| 26 | #define GEN_PASS_DEF_GPUSPIRVATTACHTARGET |
| 27 | #include "mlir/Dialect/GPU/Transforms/Passes.h.inc" |
| 28 | } // namespace mlir |
| 29 | |
| 30 | using namespace mlir; |
| 31 | using namespace mlir::spirv; |
| 32 | |
| 33 | namespace { |
| 34 | struct SPIRVAttachTarget |
| 35 | : public impl::GpuSPIRVAttachTargetBase<SPIRVAttachTarget> { |
| 36 | using Base::Base; |
| 37 | |
| 38 | void runOnOperation() override; |
| 39 | |
| 40 | void getDependentDialects(DialectRegistry ®istry) const override { |
| 41 | registry.insert<spirv::SPIRVDialect>(); |
| 42 | } |
| 43 | }; |
| 44 | } // namespace |
| 45 | |
| 46 | void SPIRVAttachTarget::runOnOperation() { |
| 47 | OpBuilder builder(&getContext()); |
| 48 | auto versionSymbol = symbolizeVersion(spirvVersion); |
| 49 | if (!versionSymbol) |
| 50 | return signalPassFailure(); |
| 51 | auto apiSymbol = symbolizeClientAPI(clientApi); |
| 52 | if (!apiSymbol) |
| 53 | return signalPassFailure(); |
| 54 | auto vendorSymbol = symbolizeVendor(deviceVendor); |
| 55 | if (!vendorSymbol) |
| 56 | return signalPassFailure(); |
| 57 | auto deviceTypeSymbol = symbolizeDeviceType(deviceType); |
| 58 | if (!deviceTypeSymbol) |
| 59 | return signalPassFailure(); |
| 60 | // Set the default device ID if none was given |
| 61 | if (!deviceId.hasValue()) |
| 62 | deviceId = mlir::spirv::TargetEnvAttr::kUnknownDeviceID; |
| 63 | |
| 64 | Version version = versionSymbol.value(); |
| 65 | SmallVector<Capability, 4> capabilities; |
| 66 | SmallVector<Extension, 8> extensions; |
| 67 | for (const auto &cap : spirvCapabilities) { |
| 68 | auto capSymbol = symbolizeCapability(cap); |
| 69 | if (capSymbol) |
| 70 | capabilities.push_back(capSymbol.value()); |
| 71 | } |
| 72 | ArrayRef<Capability> caps(capabilities); |
| 73 | for (const auto &ext : spirvExtensions) { |
| 74 | auto extSymbol = symbolizeExtension(ext); |
| 75 | if (extSymbol) |
| 76 | extensions.push_back(extSymbol.value()); |
| 77 | } |
| 78 | ArrayRef<Extension> exts(extensions); |
| 79 | VerCapExtAttr vce = VerCapExtAttr::get(version, caps, exts, &getContext()); |
| 80 | auto target = TargetEnvAttr::get(vce, getDefaultResourceLimits(&getContext()), |
| 81 | apiSymbol.value(), vendorSymbol.value(), |
| 82 | deviceTypeSymbol.value(), deviceId); |
| 83 | llvm::Regex matcher(moduleMatcher); |
| 84 | getOperation()->walk([&](gpu::GPUModuleOp gpuModule) { |
| 85 | // Check if the name of the module matches. |
| 86 | if (!moduleMatcher.empty() && !matcher.match(gpuModule.getName())) |
| 87 | return; |
| 88 | // Create the target array. |
| 89 | SmallVector<Attribute> targets; |
| 90 | if (std::optional<ArrayAttr> attrs = gpuModule.getTargets()) |
| 91 | targets.append(attrs->getValue().begin(), attrs->getValue().end()); |
| 92 | targets.push_back(Elt: target); |
| 93 | // Remove any duplicate targets. |
| 94 | targets.erase(CS: llvm::unique(R&: targets), CE: targets.end()); |
| 95 | // Update the target attribute array. |
| 96 | gpuModule.setTargetsAttr(builder.getArrayAttr(targets)); |
| 97 | }); |
| 98 | } |
| 99 | |