1//===- SPIRVAttachTarget.cpp - Attach an SPIR-V target --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the `GPUSPIRVAttachTarget` pass, attaching
10// `#spirv.target_env` attributes to GPU modules.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/GPU/Transforms/Passes.h"
15
16#include "mlir/Dialect/GPU/IR/GPUDialect.h"
17#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
18#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
19#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
20#include "mlir/IR/Builders.h"
21#include "mlir/Pass/Pass.h"
22#include "mlir/Target/SPIRV/Target.h"
23#include "llvm/Support/Regex.h"
24
25namespace mlir {
26#define GEN_PASS_DEF_GPUSPIRVATTACHTARGET
27#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
28} // namespace mlir
29
30using namespace mlir;
31using namespace mlir::spirv;
32
33namespace {
34struct SPIRVAttachTarget
35 : public impl::GpuSPIRVAttachTargetBase<SPIRVAttachTarget> {
36 using Base::Base;
37
38 void runOnOperation() override;
39
40 void getDependentDialects(DialectRegistry &registry) const override {
41 registry.insert<spirv::SPIRVDialect>();
42 }
43};
44} // namespace
45
46void SPIRVAttachTarget::runOnOperation() {
47 OpBuilder builder(&getContext());
48 auto versionSymbol = symbolizeVersion(spirvVersion);
49 if (!versionSymbol)
50 return signalPassFailure();
51 auto apiSymbol = symbolizeClientAPI(clientApi);
52 if (!apiSymbol)
53 return signalPassFailure();
54 auto vendorSymbol = symbolizeVendor(deviceVendor);
55 if (!vendorSymbol)
56 return signalPassFailure();
57 auto deviceTypeSymbol = symbolizeDeviceType(deviceType);
58 if (!deviceTypeSymbol)
59 return signalPassFailure();
60 // Set the default device ID if none was given
61 if (!deviceId.hasValue())
62 deviceId = mlir::spirv::TargetEnvAttr::kUnknownDeviceID;
63
64 Version version = versionSymbol.value();
65 SmallVector<Capability, 4> capabilities;
66 SmallVector<Extension, 8> extensions;
67 for (const auto &cap : spirvCapabilities) {
68 auto capSymbol = symbolizeCapability(cap);
69 if (capSymbol)
70 capabilities.push_back(capSymbol.value());
71 }
72 ArrayRef<Capability> caps(capabilities);
73 for (const auto &ext : spirvExtensions) {
74 auto extSymbol = symbolizeExtension(ext);
75 if (extSymbol)
76 extensions.push_back(extSymbol.value());
77 }
78 ArrayRef<Extension> exts(extensions);
79 VerCapExtAttr vce = VerCapExtAttr::get(version, caps, exts, &getContext());
80 auto target = TargetEnvAttr::get(vce, getDefaultResourceLimits(&getContext()),
81 apiSymbol.value(), vendorSymbol.value(),
82 deviceTypeSymbol.value(), deviceId);
83 llvm::Regex matcher(moduleMatcher);
84 getOperation()->walk([&](gpu::GPUModuleOp gpuModule) {
85 // Check if the name of the module matches.
86 if (!moduleMatcher.empty() && !matcher.match(gpuModule.getName()))
87 return;
88 // Create the target array.
89 SmallVector<Attribute> targets;
90 if (std::optional<ArrayAttr> attrs = gpuModule.getTargets())
91 targets.append(attrs->getValue().begin(), attrs->getValue().end());
92 targets.push_back(Elt: target);
93 // Remove any duplicate targets.
94 targets.erase(CS: std::unique(first: targets.begin(), last: targets.end()), CE: targets.end());
95 // Update the target attribute array.
96 gpuModule.setTargetsAttr(builder.getArrayAttr(targets));
97 });
98}
99

// Source: mlir/lib/Dialect/GPU/Transforms/SPIRVAttachTarget.cpp