1 | //===- TargetAndABI.cpp - SPIR-V target and ABI utilities -----------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" |
10 | #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" |
11 | #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" |
12 | #include "mlir/IR/Builders.h" |
13 | #include "mlir/IR/Operation.h" |
14 | #include "mlir/IR/SymbolTable.h" |
15 | #include "mlir/Interfaces/FunctionInterfaces.h" |
16 | #include <optional> |
17 | |
18 | using namespace mlir; |
19 | |
20 | //===----------------------------------------------------------------------===// |
21 | // TargetEnv |
22 | //===----------------------------------------------------------------------===// |
23 | |
24 | spirv::TargetEnv::TargetEnv(spirv::TargetEnvAttr targetAttr) |
25 | : targetAttr(targetAttr) { |
26 | for (spirv::Extension ext : targetAttr.getExtensions()) |
27 | givenExtensions.insert(ext); |
28 | |
29 | // Add extensions implied by the current version. |
30 | for (spirv::Extension ext : |
31 | spirv::getImpliedExtensions(version: targetAttr.getVersion())) |
32 | givenExtensions.insert(V: ext); |
33 | |
34 | for (spirv::Capability cap : targetAttr.getCapabilities()) { |
35 | givenCapabilities.insert(cap); |
36 | |
37 | // Add capabilities implied by the current capability. |
38 | for (spirv::Capability c : spirv::getRecursiveImpliedCapabilities(cap)) |
39 | givenCapabilities.insert(c); |
40 | } |
41 | } |
42 | |
43 | spirv::Version spirv::TargetEnv::getVersion() const { |
44 | return targetAttr.getVersion(); |
45 | } |
46 | |
47 | bool spirv::TargetEnv::allows(spirv::Capability capability) const { |
48 | return givenCapabilities.count(V: capability); |
49 | } |
50 | |
51 | std::optional<spirv::Capability> |
52 | spirv::TargetEnv::allows(ArrayRef<spirv::Capability> caps) const { |
53 | const auto *chosen = llvm::find_if(Range&: caps, P: [this](spirv::Capability cap) { |
54 | return givenCapabilities.count(V: cap); |
55 | }); |
56 | if (chosen != caps.end()) |
57 | return *chosen; |
58 | return std::nullopt; |
59 | } |
60 | |
61 | bool spirv::TargetEnv::allows(spirv::Extension extension) const { |
62 | return givenExtensions.count(V: extension); |
63 | } |
64 | |
65 | std::optional<spirv::Extension> |
66 | spirv::TargetEnv::allows(ArrayRef<spirv::Extension> exts) const { |
67 | const auto *chosen = llvm::find_if(Range&: exts, P: [this](spirv::Extension ext) { |
68 | return givenExtensions.count(V: ext); |
69 | }); |
70 | if (chosen != exts.end()) |
71 | return *chosen; |
72 | return std::nullopt; |
73 | } |
74 | |
75 | spirv::Vendor spirv::TargetEnv::getVendorID() const { |
76 | return targetAttr.getVendorID(); |
77 | } |
78 | |
79 | spirv::DeviceType spirv::TargetEnv::getDeviceType() const { |
80 | return targetAttr.getDeviceType(); |
81 | } |
82 | |
83 | uint32_t spirv::TargetEnv::getDeviceID() const { |
84 | return targetAttr.getDeviceID(); |
85 | } |
86 | |
87 | spirv::ResourceLimitsAttr spirv::TargetEnv::getResourceLimits() const { |
88 | return targetAttr.getResourceLimits(); |
89 | } |
90 | |
91 | MLIRContext *spirv::TargetEnv::getContext() const { |
92 | return targetAttr.getContext(); |
93 | } |
94 | |
95 | //===----------------------------------------------------------------------===// |
96 | // Utility functions |
97 | //===----------------------------------------------------------------------===// |
98 | |
99 | StringRef spirv::getInterfaceVarABIAttrName() { |
100 | return "spirv.interface_var_abi" ; |
101 | } |
102 | |
103 | spirv::InterfaceVarABIAttr |
104 | spirv::getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding, |
105 | std::optional<spirv::StorageClass> storageClass, |
106 | MLIRContext *context) { |
107 | return spirv::InterfaceVarABIAttr::get(descriptorSet, binding, storageClass, |
108 | context); |
109 | } |
110 | |
111 | bool spirv::needsInterfaceVarABIAttrs(spirv::TargetEnvAttr targetAttr) { |
112 | for (spirv::Capability cap : targetAttr.getCapabilities()) { |
113 | if (cap == spirv::Capability::Kernel) |
114 | return false; |
115 | if (cap == spirv::Capability::Shader) |
116 | return true; |
117 | } |
118 | return false; |
119 | } |
120 | |
121 | StringRef spirv::getEntryPointABIAttrName() { return "spirv.entry_point_abi" ; } |
122 | |
123 | spirv::EntryPointABIAttr spirv::getEntryPointABIAttr( |
124 | MLIRContext *context, ArrayRef<int32_t> workgroupSize, |
125 | std::optional<int> subgroupSize, std::optional<int> targetWidth) { |
126 | DenseI32ArrayAttr workgroupSizeAttr; |
127 | if (!workgroupSize.empty()) { |
128 | assert(workgroupSize.size() == 3); |
129 | workgroupSizeAttr = DenseI32ArrayAttr::get(context, workgroupSize); |
130 | } |
131 | return spirv::EntryPointABIAttr::get(context, workgroupSizeAttr, subgroupSize, |
132 | targetWidth); |
133 | } |
134 | |
135 | spirv::EntryPointABIAttr spirv::lookupEntryPointABI(Operation *op) { |
136 | while (op && !isa<FunctionOpInterface>(Val: op)) |
137 | op = op->getParentOp(); |
138 | if (!op) |
139 | return {}; |
140 | |
141 | if (auto attr = op->getAttrOfType<spirv::EntryPointABIAttr>( |
142 | spirv::getEntryPointABIAttrName())) |
143 | return attr; |
144 | |
145 | return {}; |
146 | } |
147 | |
148 | DenseI32ArrayAttr spirv::lookupLocalWorkGroupSize(Operation *op) { |
149 | if (auto entryPoint = spirv::lookupEntryPointABI(op)) |
150 | return entryPoint.getWorkgroupSize(); |
151 | |
152 | return {}; |
153 | } |
154 | |
155 | spirv::ResourceLimitsAttr |
156 | spirv::getDefaultResourceLimits(MLIRContext *context) { |
157 | // All the fields have default values. Here we just provide a nicer way to |
158 | // construct a default resource limit attribute. |
159 | Builder b(context); |
160 | return spirv::ResourceLimitsAttr::get( |
161 | context, |
162 | /*max_compute_shared_memory_size=*/16384, |
163 | /*max_compute_workgroup_invocations=*/128, |
164 | /*max_compute_workgroup_size=*/b.getI32ArrayAttr({128, 128, 64}), |
165 | /*subgroup_size=*/32, |
166 | /*min_subgroup_size=*/std::nullopt, |
167 | /*max_subgroup_size=*/std::nullopt, |
168 | /*cooperative_matrix_properties_khr=*/ArrayAttr{}, |
169 | /*cooperative_matrix_properties_nv=*/ArrayAttr{}); |
170 | } |
171 | |
172 | StringRef spirv::getTargetEnvAttrName() { return "spirv.target_env" ; } |
173 | |
174 | spirv::TargetEnvAttr spirv::getDefaultTargetEnv(MLIRContext *context) { |
175 | auto triple = spirv::VerCapExtAttr::get(spirv::Version::V_1_0, |
176 | {spirv::Capability::Shader}, |
177 | ArrayRef<Extension>(), context); |
178 | return spirv::TargetEnvAttr::get( |
179 | triple, spirv::getDefaultResourceLimits(context), |
180 | spirv::ClientAPI::Unknown, spirv::Vendor::Unknown, |
181 | spirv::DeviceType::Unknown, spirv::TargetEnvAttr::kUnknownDeviceID); |
182 | } |
183 | |
184 | spirv::TargetEnvAttr spirv::lookupTargetEnv(Operation *op) { |
185 | while (op) { |
186 | op = SymbolTable::getNearestSymbolTable(from: op); |
187 | if (!op) |
188 | break; |
189 | |
190 | if (auto attr = op->getAttrOfType<spirv::TargetEnvAttr>( |
191 | spirv::getTargetEnvAttrName())) |
192 | return attr; |
193 | |
194 | op = op->getParentOp(); |
195 | } |
196 | |
197 | return {}; |
198 | } |
199 | |
200 | spirv::TargetEnvAttr spirv::lookupTargetEnvOrDefault(Operation *op) { |
201 | if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(op)) |
202 | return attr; |
203 | |
204 | return getDefaultTargetEnv(context: op->getContext()); |
205 | } |
206 | |
207 | spirv::AddressingModel |
208 | spirv::getAddressingModel(spirv::TargetEnvAttr targetAttr, |
209 | bool use64bitAddress) { |
210 | for (spirv::Capability cap : targetAttr.getCapabilities()) { |
211 | if (cap == Capability::Kernel) |
212 | return use64bitAddress ? spirv::AddressingModel::Physical64 |
213 | : spirv::AddressingModel::Physical32; |
214 | // TODO PhysicalStorageBuffer64 is hard-coded here, but some information |
215 | // should come from TargetEnvAttr to select between PhysicalStorageBuffer64 |
216 | // and PhysicalStorageBuffer64EXT |
217 | if (cap == Capability::PhysicalStorageBufferAddresses) |
218 | return spirv::AddressingModel::PhysicalStorageBuffer64; |
219 | } |
220 | // Logical addressing doesn't need any capabilities so return it as default. |
221 | return spirv::AddressingModel::Logical; |
222 | } |
223 | |
224 | FailureOr<spirv::ExecutionModel> |
225 | spirv::getExecutionModel(spirv::TargetEnvAttr targetAttr) { |
226 | for (spirv::Capability cap : targetAttr.getCapabilities()) { |
227 | if (cap == spirv::Capability::Kernel) |
228 | return spirv::ExecutionModel::Kernel; |
229 | if (cap == spirv::Capability::Shader) |
230 | return spirv::ExecutionModel::GLCompute; |
231 | } |
232 | return failure(); |
233 | } |
234 | |
235 | FailureOr<spirv::MemoryModel> |
236 | spirv::getMemoryModel(spirv::TargetEnvAttr targetAttr) { |
237 | for (spirv::Capability cap : targetAttr.getCapabilities()) { |
238 | if (cap == spirv::Capability::Kernel) |
239 | return spirv::MemoryModel::OpenCL; |
240 | if (cap == spirv::Capability::Shader) |
241 | return spirv::MemoryModel::GLSL450; |
242 | } |
243 | return failure(); |
244 | } |
245 | |