1//===- TargetAndABI.cpp - SPIR-V target and ABI utilities -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
10#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
11#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
12#include "mlir/IR/Builders.h"
13#include "mlir/IR/Operation.h"
14#include "mlir/IR/SymbolTable.h"
15#include "mlir/Interfaces/FunctionInterfaces.h"
16#include <optional>
17
18using namespace mlir;
19
20//===----------------------------------------------------------------------===//
21// TargetEnv
22//===----------------------------------------------------------------------===//
23
24spirv::TargetEnv::TargetEnv(spirv::TargetEnvAttr targetAttr)
25 : targetAttr(targetAttr) {
26 for (spirv::Extension ext : targetAttr.getExtensions())
27 givenExtensions.insert(ext);
28
29 // Add extensions implied by the current version.
30 for (spirv::Extension ext :
31 spirv::getImpliedExtensions(version: targetAttr.getVersion()))
32 givenExtensions.insert(V: ext);
33
34 for (spirv::Capability cap : targetAttr.getCapabilities()) {
35 givenCapabilities.insert(cap);
36
37 // Add capabilities implied by the current capability.
38 for (spirv::Capability c : spirv::getRecursiveImpliedCapabilities(cap))
39 givenCapabilities.insert(c);
40 }
41}
42
43spirv::Version spirv::TargetEnv::getVersion() const {
44 return targetAttr.getVersion();
45}
46
47bool spirv::TargetEnv::allows(spirv::Capability capability) const {
48 return givenCapabilities.count(V: capability);
49}
50
51std::optional<spirv::Capability>
52spirv::TargetEnv::allows(ArrayRef<spirv::Capability> caps) const {
53 const auto *chosen = llvm::find_if(Range&: caps, P: [this](spirv::Capability cap) {
54 return givenCapabilities.count(V: cap);
55 });
56 if (chosen != caps.end())
57 return *chosen;
58 return std::nullopt;
59}
60
61bool spirv::TargetEnv::allows(spirv::Extension extension) const {
62 return givenExtensions.count(V: extension);
63}
64
65std::optional<spirv::Extension>
66spirv::TargetEnv::allows(ArrayRef<spirv::Extension> exts) const {
67 const auto *chosen = llvm::find_if(Range&: exts, P: [this](spirv::Extension ext) {
68 return givenExtensions.count(V: ext);
69 });
70 if (chosen != exts.end())
71 return *chosen;
72 return std::nullopt;
73}
74
75spirv::Vendor spirv::TargetEnv::getVendorID() const {
76 return targetAttr.getVendorID();
77}
78
79spirv::DeviceType spirv::TargetEnv::getDeviceType() const {
80 return targetAttr.getDeviceType();
81}
82
83uint32_t spirv::TargetEnv::getDeviceID() const {
84 return targetAttr.getDeviceID();
85}
86
87spirv::ResourceLimitsAttr spirv::TargetEnv::getResourceLimits() const {
88 return targetAttr.getResourceLimits();
89}
90
91MLIRContext *spirv::TargetEnv::getContext() const {
92 return targetAttr.getContext();
93}
94
95//===----------------------------------------------------------------------===//
96// Utility functions
97//===----------------------------------------------------------------------===//
98
99StringRef spirv::getInterfaceVarABIAttrName() {
100 return "spirv.interface_var_abi";
101}
102
103spirv::InterfaceVarABIAttr
104spirv::getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding,
105 std::optional<spirv::StorageClass> storageClass,
106 MLIRContext *context) {
107 return spirv::InterfaceVarABIAttr::get(descriptorSet, binding, storageClass,
108 context);
109}
110
111bool spirv::needsInterfaceVarABIAttrs(spirv::TargetEnvAttr targetAttr) {
112 for (spirv::Capability cap : targetAttr.getCapabilities()) {
113 if (cap == spirv::Capability::Kernel)
114 return false;
115 if (cap == spirv::Capability::Shader)
116 return true;
117 }
118 return false;
119}
120
121StringRef spirv::getEntryPointABIAttrName() { return "spirv.entry_point_abi"; }
122
123spirv::EntryPointABIAttr spirv::getEntryPointABIAttr(
124 MLIRContext *context, ArrayRef<int32_t> workgroupSize,
125 std::optional<int> subgroupSize, std::optional<int> targetWidth) {
126 DenseI32ArrayAttr workgroupSizeAttr;
127 if (!workgroupSize.empty()) {
128 assert(workgroupSize.size() == 3);
129 workgroupSizeAttr = DenseI32ArrayAttr::get(context, workgroupSize);
130 }
131 return spirv::EntryPointABIAttr::get(context, workgroupSizeAttr, subgroupSize,
132 targetWidth);
133}
134
135spirv::EntryPointABIAttr spirv::lookupEntryPointABI(Operation *op) {
136 while (op && !isa<FunctionOpInterface>(Val: op))
137 op = op->getParentOp();
138 if (!op)
139 return {};
140
141 if (auto attr = op->getAttrOfType<spirv::EntryPointABIAttr>(
142 spirv::getEntryPointABIAttrName()))
143 return attr;
144
145 return {};
146}
147
148DenseI32ArrayAttr spirv::lookupLocalWorkGroupSize(Operation *op) {
149 if (auto entryPoint = spirv::lookupEntryPointABI(op))
150 return entryPoint.getWorkgroupSize();
151
152 return {};
153}
154
155spirv::ResourceLimitsAttr
156spirv::getDefaultResourceLimits(MLIRContext *context) {
157 // All the fields have default values. Here we just provide a nicer way to
158 // construct a default resource limit attribute.
159 Builder b(context);
160 return spirv::ResourceLimitsAttr::get(
161 context,
162 /*max_compute_shared_memory_size=*/16384,
163 /*max_compute_workgroup_invocations=*/128,
164 /*max_compute_workgroup_size=*/b.getI32ArrayAttr({128, 128, 64}),
165 /*subgroup_size=*/32,
166 /*min_subgroup_size=*/std::nullopt,
167 /*max_subgroup_size=*/std::nullopt,
168 /*cooperative_matrix_properties_khr=*/ArrayAttr{},
169 /*cooperative_matrix_properties_nv=*/ArrayAttr{});
170}
171
172StringRef spirv::getTargetEnvAttrName() { return "spirv.target_env"; }
173
174spirv::TargetEnvAttr spirv::getDefaultTargetEnv(MLIRContext *context) {
175 auto triple = spirv::VerCapExtAttr::get(spirv::Version::V_1_0,
176 {spirv::Capability::Shader},
177 ArrayRef<Extension>(), context);
178 return spirv::TargetEnvAttr::get(
179 triple, spirv::getDefaultResourceLimits(context),
180 spirv::ClientAPI::Unknown, spirv::Vendor::Unknown,
181 spirv::DeviceType::Unknown, spirv::TargetEnvAttr::kUnknownDeviceID);
182}
183
184spirv::TargetEnvAttr spirv::lookupTargetEnv(Operation *op) {
185 while (op) {
186 op = SymbolTable::getNearestSymbolTable(from: op);
187 if (!op)
188 break;
189
190 if (auto attr = op->getAttrOfType<spirv::TargetEnvAttr>(
191 spirv::getTargetEnvAttrName()))
192 return attr;
193
194 op = op->getParentOp();
195 }
196
197 return {};
198}
199
200spirv::TargetEnvAttr spirv::lookupTargetEnvOrDefault(Operation *op) {
201 if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(op))
202 return attr;
203
204 return getDefaultTargetEnv(context: op->getContext());
205}
206
207spirv::AddressingModel
208spirv::getAddressingModel(spirv::TargetEnvAttr targetAttr,
209 bool use64bitAddress) {
210 for (spirv::Capability cap : targetAttr.getCapabilities()) {
211 if (cap == Capability::Kernel)
212 return use64bitAddress ? spirv::AddressingModel::Physical64
213 : spirv::AddressingModel::Physical32;
214 // TODO PhysicalStorageBuffer64 is hard-coded here, but some information
215 // should come from TargetEnvAttr to select between PhysicalStorageBuffer64
216 // and PhysicalStorageBuffer64EXT
217 if (cap == Capability::PhysicalStorageBufferAddresses)
218 return spirv::AddressingModel::PhysicalStorageBuffer64;
219 }
220 // Logical addressing doesn't need any capabilities so return it as default.
221 return spirv::AddressingModel::Logical;
222}
223
224FailureOr<spirv::ExecutionModel>
225spirv::getExecutionModel(spirv::TargetEnvAttr targetAttr) {
226 for (spirv::Capability cap : targetAttr.getCapabilities()) {
227 if (cap == spirv::Capability::Kernel)
228 return spirv::ExecutionModel::Kernel;
229 if (cap == spirv::Capability::Shader)
230 return spirv::ExecutionModel::GLCompute;
231 }
232 return failure();
233}
234
235FailureOr<spirv::MemoryModel>
236spirv::getMemoryModel(spirv::TargetEnvAttr targetAttr) {
237 for (spirv::Capability cap : targetAttr.getCapabilities()) {
238 if (cap == spirv::Capability::Kernel)
239 return spirv::MemoryModel::OpenCL;
240 if (cap == spirv::Capability::Shader)
241 return spirv::MemoryModel::GLSL450;
242 }
243 return failure();
244}
245

// Source file: mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp