1 | //===- TargetAndABI.cpp - SPIR-V target and ABI utilities -----------------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" |
10 | #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" |
11 | #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" |
12 | #include "mlir/IR/Builders.h" |
13 | #include "mlir/IR/Operation.h" |
14 | #include "mlir/IR/SymbolTable.h" |
15 | #include "mlir/Interfaces/FunctionInterfaces.h" |
16 | #include <optional> |
17 | |
18 | using namespace mlir; |
19 | |
20 | //===----------------------------------------------------------------------===// |
21 | // TargetEnv |
22 | //===----------------------------------------------------------------------===// |
23 | |
24 | spirv::TargetEnv::TargetEnv(spirv::TargetEnvAttr targetAttr) |
25 | : targetAttr(targetAttr) { |
26 | givenExtensions.insert_range(R: targetAttr.getExtensions()); |
27 | |
28 | // Add extensions implied by the current version. |
29 | givenExtensions.insert_range( |
30 | R: spirv::getImpliedExtensions(version: targetAttr.getVersion())); |
31 | |
32 | for (spirv::Capability cap : targetAttr.getCapabilities()) { |
33 | givenCapabilities.insert(cap); |
34 | |
35 | // Add capabilities implied by the current capability. |
36 | givenCapabilities.insert_range(spirv::getRecursiveImpliedCapabilities(cap)); |
37 | } |
38 | } |
39 | |
40 | spirv::Version spirv::TargetEnv::getVersion() const { |
41 | return targetAttr.getVersion(); |
42 | } |
43 | |
44 | bool spirv::TargetEnv::allows(spirv::Capability capability) const { |
45 | return givenCapabilities.count(V: capability); |
46 | } |
47 | |
48 | std::optional<spirv::Capability> |
49 | spirv::TargetEnv::allows(ArrayRef<spirv::Capability> caps) const { |
50 | const auto *chosen = llvm::find_if(Range&: caps, P: [this](spirv::Capability cap) { |
51 | return givenCapabilities.count(V: cap); |
52 | }); |
53 | if (chosen != caps.end()) |
54 | return *chosen; |
55 | return std::nullopt; |
56 | } |
57 | |
58 | bool spirv::TargetEnv::allows(spirv::Extension extension) const { |
59 | return givenExtensions.count(V: extension); |
60 | } |
61 | |
62 | std::optional<spirv::Extension> |
63 | spirv::TargetEnv::allows(ArrayRef<spirv::Extension> exts) const { |
64 | const auto *chosen = llvm::find_if(Range&: exts, P: [this](spirv::Extension ext) { |
65 | return givenExtensions.count(V: ext); |
66 | }); |
67 | if (chosen != exts.end()) |
68 | return *chosen; |
69 | return std::nullopt; |
70 | } |
71 | |
72 | spirv::Vendor spirv::TargetEnv::getVendorID() const { |
73 | return targetAttr.getVendorID(); |
74 | } |
75 | |
76 | spirv::DeviceType spirv::TargetEnv::getDeviceType() const { |
77 | return targetAttr.getDeviceType(); |
78 | } |
79 | |
80 | uint32_t spirv::TargetEnv::getDeviceID() const { |
81 | return targetAttr.getDeviceID(); |
82 | } |
83 | |
84 | spirv::ResourceLimitsAttr spirv::TargetEnv::getResourceLimits() const { |
85 | return targetAttr.getResourceLimits(); |
86 | } |
87 | |
88 | MLIRContext *spirv::TargetEnv::getContext() const { |
89 | return targetAttr.getContext(); |
90 | } |
91 | |
92 | //===----------------------------------------------------------------------===// |
93 | // Utility functions |
94 | //===----------------------------------------------------------------------===// |
95 | |
96 | StringRef spirv::getInterfaceVarABIAttrName() { |
97 | return "spirv.interface_var_abi"; |
98 | } |
99 | |
100 | spirv::InterfaceVarABIAttr |
101 | spirv::getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding, |
102 | std::optional<spirv::StorageClass> storageClass, |
103 | MLIRContext *context) { |
104 | return spirv::InterfaceVarABIAttr::get(descriptorSet, binding, storageClass, |
105 | context); |
106 | } |
107 | |
108 | bool spirv::needsInterfaceVarABIAttrs(spirv::TargetEnvAttr targetAttr) { |
109 | for (spirv::Capability cap : targetAttr.getCapabilities()) { |
110 | if (cap == spirv::Capability::Kernel) |
111 | return false; |
112 | if (cap == spirv::Capability::Shader) |
113 | return true; |
114 | } |
115 | return false; |
116 | } |
117 | |
118 | StringRef spirv::getEntryPointABIAttrName() { return "spirv.entry_point_abi"; } |
119 | |
120 | spirv::EntryPointABIAttr spirv::getEntryPointABIAttr( |
121 | MLIRContext *context, ArrayRef<int32_t> workgroupSize, |
122 | std::optional<int> subgroupSize, std::optional<int> targetWidth) { |
123 | DenseI32ArrayAttr workgroupSizeAttr; |
124 | if (!workgroupSize.empty()) { |
125 | assert(workgroupSize.size() == 3); |
126 | workgroupSizeAttr = DenseI32ArrayAttr::get(context, workgroupSize); |
127 | } |
128 | return spirv::EntryPointABIAttr::get(context, workgroupSizeAttr, subgroupSize, |
129 | targetWidth); |
130 | } |
131 | |
132 | spirv::EntryPointABIAttr spirv::lookupEntryPointABI(Operation *op) { |
133 | while (op && !isa<FunctionOpInterface>(Val: op)) |
134 | op = op->getParentOp(); |
135 | if (!op) |
136 | return {}; |
137 | |
138 | if (auto attr = op->getAttrOfType<spirv::EntryPointABIAttr>( |
139 | spirv::getEntryPointABIAttrName())) |
140 | return attr; |
141 | |
142 | return {}; |
143 | } |
144 | |
145 | DenseI32ArrayAttr spirv::lookupLocalWorkGroupSize(Operation *op) { |
146 | if (auto entryPoint = spirv::lookupEntryPointABI(op)) |
147 | return entryPoint.getWorkgroupSize(); |
148 | |
149 | return {}; |
150 | } |
151 | |
152 | spirv::ResourceLimitsAttr |
153 | spirv::getDefaultResourceLimits(MLIRContext *context) { |
154 | // All the fields have default values. Here we just provide a nicer way to |
155 | // construct a default resource limit attribute. |
156 | Builder b(context); |
157 | return spirv::ResourceLimitsAttr::get( |
158 | context, |
159 | /*max_compute_shared_memory_size=*/16384, |
160 | /*max_compute_workgroup_invocations=*/128, |
161 | /*max_compute_workgroup_size=*/b.getI32ArrayAttr({128, 128, 64}), |
162 | /*subgroup_size=*/32, |
163 | /*min_subgroup_size=*/std::nullopt, |
164 | /*max_subgroup_size=*/std::nullopt, |
165 | /*cooperative_matrix_properties_khr=*/ArrayAttr{}, |
166 | /*cooperative_matrix_properties_nv=*/ArrayAttr{}); |
167 | } |
168 | |
169 | StringRef spirv::getTargetEnvAttrName() { return "spirv.target_env"; } |
170 | |
171 | spirv::TargetEnvAttr spirv::getDefaultTargetEnv(MLIRContext *context) { |
172 | auto triple = spirv::VerCapExtAttr::get(spirv::Version::V_1_0, |
173 | {spirv::Capability::Shader}, |
174 | ArrayRef<Extension>(), context); |
175 | return spirv::TargetEnvAttr::get( |
176 | triple, spirv::getDefaultResourceLimits(context), |
177 | spirv::ClientAPI::Unknown, spirv::Vendor::Unknown, |
178 | spirv::DeviceType::Unknown, spirv::TargetEnvAttr::kUnknownDeviceID); |
179 | } |
180 | |
181 | spirv::TargetEnvAttr spirv::lookupTargetEnv(Operation *op) { |
182 | while (op) { |
183 | op = SymbolTable::getNearestSymbolTable(from: op); |
184 | if (!op) |
185 | break; |
186 | |
187 | if (auto attr = op->getAttrOfType<spirv::TargetEnvAttr>( |
188 | spirv::getTargetEnvAttrName())) |
189 | return attr; |
190 | |
191 | op = op->getParentOp(); |
192 | } |
193 | |
194 | return {}; |
195 | } |
196 | |
197 | spirv::TargetEnvAttr spirv::lookupTargetEnvOrDefault(Operation *op) { |
198 | if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(op)) |
199 | return attr; |
200 | |
201 | return getDefaultTargetEnv(context: op->getContext()); |
202 | } |
203 | |
204 | spirv::AddressingModel |
205 | spirv::getAddressingModel(spirv::TargetEnvAttr targetAttr, |
206 | bool use64bitAddress) { |
207 | for (spirv::Capability cap : targetAttr.getCapabilities()) { |
208 | if (cap == Capability::Kernel) |
209 | return use64bitAddress ? spirv::AddressingModel::Physical64 |
210 | : spirv::AddressingModel::Physical32; |
211 | // TODO PhysicalStorageBuffer64 is hard-coded here, but some information |
212 | // should come from TargetEnvAttr to select between PhysicalStorageBuffer64 |
213 | // and PhysicalStorageBuffer64EXT |
214 | if (cap == Capability::PhysicalStorageBufferAddresses) |
215 | return spirv::AddressingModel::PhysicalStorageBuffer64; |
216 | } |
217 | // Logical addressing doesn't need any capabilities so return it as default. |
218 | return spirv::AddressingModel::Logical; |
219 | } |
220 | |
221 | FailureOr<spirv::ExecutionModel> |
222 | spirv::getExecutionModel(spirv::TargetEnvAttr targetAttr) { |
223 | for (spirv::Capability cap : targetAttr.getCapabilities()) { |
224 | if (cap == spirv::Capability::Kernel) |
225 | return spirv::ExecutionModel::Kernel; |
226 | if (cap == spirv::Capability::Shader) |
227 | return spirv::ExecutionModel::GLCompute; |
228 | } |
229 | return failure(); |
230 | } |
231 | |
232 | FailureOr<spirv::MemoryModel> |
233 | spirv::getMemoryModel(spirv::TargetEnvAttr targetAttr) { |
234 | for (spirv::Capability cap : targetAttr.getCapabilities()) { |
235 | if (cap == spirv::Capability::Kernel) |
236 | return spirv::MemoryModel::OpenCL; |
237 | if (cap == spirv::Capability::Shader) |
238 | return spirv::MemoryModel::GLSL450; |
239 | } |
240 | return failure(); |
241 | } |
242 |
Definitions
- TargetEnv
- getVersion
- allows
- allows
- allows
- allows
- getVendorID
- getDeviceType
- getDeviceID
- getResourceLimits
- getContext
- getInterfaceVarABIAttrName
- getInterfaceVarABIAttr
- needsInterfaceVarABIAttrs
- getEntryPointABIAttrName
- getEntryPointABIAttr
- lookupEntryPointABI
- lookupLocalWorkGroupSize
- getDefaultResourceLimits
- getTargetEnvAttrName
- getDefaultTargetEnv
- lookupTargetEnv
- lookupTargetEnvOrDefault
- getAddressingModel
- getExecutionModel
Learn to use CMake with our Intro Training
Find out more