//===- SPIRVAttachTarget.cpp - Attach a SPIR-V target ---------------------===//
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This file implements the `GpuSPIRVAttachTarget` pass, attaching
10 | // `#spirv.target_env` attributes to GPU modules. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Dialect/GPU/Transforms/Passes.h" |
15 | |
16 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
17 | #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" |
18 | #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" |
19 | #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" |
20 | #include "mlir/IR/Builders.h" |
21 | #include "mlir/Pass/Pass.h" |
22 | #include "mlir/Target/SPIRV/Target.h" |
23 | #include "llvm/Support/Regex.h" |
24 | |
namespace mlir {
// Expand the generated definition of the pass base class
// `impl::GpuSPIRVAttachTargetBase` (derived from below); the GEN_PASS_DEF_*
// macro selects which pass definition the `.inc` file emits.
#define GEN_PASS_DEF_GPUSPIRVATTACHTARGET
#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
} // namespace mlir
29 | |
30 | using namespace mlir; |
31 | using namespace mlir::spirv; |
32 | |
33 | namespace { |
34 | struct SPIRVAttachTarget |
35 | : public impl::GpuSPIRVAttachTargetBase<SPIRVAttachTarget> { |
36 | using Base::Base; |
37 | |
38 | void runOnOperation() override; |
39 | |
40 | void getDependentDialects(DialectRegistry ®istry) const override { |
41 | registry.insert<spirv::SPIRVDialect>(); |
42 | } |
43 | }; |
44 | } // namespace |
45 | |
46 | void SPIRVAttachTarget::runOnOperation() { |
47 | OpBuilder builder(&getContext()); |
48 | auto versionSymbol = symbolizeVersion(spirvVersion); |
49 | if (!versionSymbol) |
50 | return signalPassFailure(); |
51 | auto apiSymbol = symbolizeClientAPI(clientApi); |
52 | if (!apiSymbol) |
53 | return signalPassFailure(); |
54 | auto vendorSymbol = symbolizeVendor(deviceVendor); |
55 | if (!vendorSymbol) |
56 | return signalPassFailure(); |
57 | auto deviceTypeSymbol = symbolizeDeviceType(deviceType); |
58 | if (!deviceTypeSymbol) |
59 | return signalPassFailure(); |
60 | // Set the default device ID if none was given |
61 | if (!deviceId.hasValue()) |
62 | deviceId = mlir::spirv::TargetEnvAttr::kUnknownDeviceID; |
63 | |
64 | Version version = versionSymbol.value(); |
65 | SmallVector<Capability, 4> capabilities; |
66 | SmallVector<Extension, 8> extensions; |
67 | for (const auto &cap : spirvCapabilities) { |
68 | auto capSymbol = symbolizeCapability(cap); |
69 | if (capSymbol) |
70 | capabilities.push_back(capSymbol.value()); |
71 | } |
72 | ArrayRef<Capability> caps(capabilities); |
73 | for (const auto &ext : spirvExtensions) { |
74 | auto extSymbol = symbolizeExtension(ext); |
75 | if (extSymbol) |
76 | extensions.push_back(extSymbol.value()); |
77 | } |
78 | ArrayRef<Extension> exts(extensions); |
79 | VerCapExtAttr vce = VerCapExtAttr::get(version, caps, exts, &getContext()); |
80 | auto target = TargetEnvAttr::get(vce, getDefaultResourceLimits(&getContext()), |
81 | apiSymbol.value(), vendorSymbol.value(), |
82 | deviceTypeSymbol.value(), deviceId); |
83 | llvm::Regex matcher(moduleMatcher); |
84 | getOperation()->walk([&](gpu::GPUModuleOp gpuModule) { |
85 | // Check if the name of the module matches. |
86 | if (!moduleMatcher.empty() && !matcher.match(gpuModule.getName())) |
87 | return; |
88 | // Create the target array. |
89 | SmallVector<Attribute> targets; |
90 | if (std::optional<ArrayAttr> attrs = gpuModule.getTargets()) |
91 | targets.append(attrs->getValue().begin(), attrs->getValue().end()); |
92 | targets.push_back(Elt: target); |
93 | // Remove any duplicate targets. |
94 | targets.erase(CS: std::unique(first: targets.begin(), last: targets.end()), CE: targets.end()); |
95 | // Update the target attribute array. |
96 | gpuModule.setTargetsAttr(builder.getArrayAttr(targets)); |
97 | }); |
98 | } |
99 | |