//===- TargetAndABI.cpp - SPIR-V target and ABI utilities -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
10#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
11#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
12#include "mlir/IR/Builders.h"
13#include "mlir/IR/Operation.h"
14#include "mlir/IR/SymbolTable.h"
15#include "mlir/Interfaces/FunctionInterfaces.h"
16#include <optional>
17
18using namespace mlir;
19
//===----------------------------------------------------------------------===//
// TargetEnv
//===----------------------------------------------------------------------===//
23
24spirv::TargetEnv::TargetEnv(spirv::TargetEnvAttr targetAttr)
25 : targetAttr(targetAttr) {
26 givenExtensions.insert_range(R: targetAttr.getExtensions());
27
28 // Add extensions implied by the current version.
29 givenExtensions.insert_range(
30 R: spirv::getImpliedExtensions(version: targetAttr.getVersion()));
31
32 for (spirv::Capability cap : targetAttr.getCapabilities()) {
33 givenCapabilities.insert(cap);
34
35 // Add capabilities implied by the current capability.
36 givenCapabilities.insert_range(spirv::getRecursiveImpliedCapabilities(cap));
37 }
38}
39
40spirv::Version spirv::TargetEnv::getVersion() const {
41 return targetAttr.getVersion();
42}
43
44bool spirv::TargetEnv::allows(spirv::Capability capability) const {
45 return givenCapabilities.count(V: capability);
46}
47
48std::optional<spirv::Capability>
49spirv::TargetEnv::allows(ArrayRef<spirv::Capability> caps) const {
50 const auto *chosen = llvm::find_if(Range&: caps, P: [this](spirv::Capability cap) {
51 return givenCapabilities.count(V: cap);
52 });
53 if (chosen != caps.end())
54 return *chosen;
55 return std::nullopt;
56}
57
58bool spirv::TargetEnv::allows(spirv::Extension extension) const {
59 return givenExtensions.count(V: extension);
60}
61
62std::optional<spirv::Extension>
63spirv::TargetEnv::allows(ArrayRef<spirv::Extension> exts) const {
64 const auto *chosen = llvm::find_if(Range&: exts, P: [this](spirv::Extension ext) {
65 return givenExtensions.count(V: ext);
66 });
67 if (chosen != exts.end())
68 return *chosen;
69 return std::nullopt;
70}
71
72spirv::Vendor spirv::TargetEnv::getVendorID() const {
73 return targetAttr.getVendorID();
74}
75
76spirv::DeviceType spirv::TargetEnv::getDeviceType() const {
77 return targetAttr.getDeviceType();
78}
79
80uint32_t spirv::TargetEnv::getDeviceID() const {
81 return targetAttr.getDeviceID();
82}
83
84spirv::ResourceLimitsAttr spirv::TargetEnv::getResourceLimits() const {
85 return targetAttr.getResourceLimits();
86}
87
88MLIRContext *spirv::TargetEnv::getContext() const {
89 return targetAttr.getContext();
90}
91
//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//
95
96StringRef spirv::getInterfaceVarABIAttrName() {
97 return "spirv.interface_var_abi";
98}
99
100spirv::InterfaceVarABIAttr
101spirv::getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding,
102 std::optional<spirv::StorageClass> storageClass,
103 MLIRContext *context) {
104 return spirv::InterfaceVarABIAttr::get(descriptorSet, binding, storageClass,
105 context);
106}
107
108bool spirv::needsInterfaceVarABIAttrs(spirv::TargetEnvAttr targetAttr) {
109 for (spirv::Capability cap : targetAttr.getCapabilities()) {
110 if (cap == spirv::Capability::Kernel)
111 return false;
112 if (cap == spirv::Capability::Shader)
113 return true;
114 }
115 return false;
116}
117
118StringRef spirv::getEntryPointABIAttrName() { return "spirv.entry_point_abi"; }
119
120spirv::EntryPointABIAttr spirv::getEntryPointABIAttr(
121 MLIRContext *context, ArrayRef<int32_t> workgroupSize,
122 std::optional<int> subgroupSize, std::optional<int> targetWidth) {
123 DenseI32ArrayAttr workgroupSizeAttr;
124 if (!workgroupSize.empty()) {
125 assert(workgroupSize.size() == 3);
126 workgroupSizeAttr = DenseI32ArrayAttr::get(context, workgroupSize);
127 }
128 return spirv::EntryPointABIAttr::get(context, workgroupSizeAttr, subgroupSize,
129 targetWidth);
130}
131
132spirv::EntryPointABIAttr spirv::lookupEntryPointABI(Operation *op) {
133 while (op && !isa<FunctionOpInterface>(Val: op))
134 op = op->getParentOp();
135 if (!op)
136 return {};
137
138 if (auto attr = op->getAttrOfType<spirv::EntryPointABIAttr>(
139 spirv::getEntryPointABIAttrName()))
140 return attr;
141
142 return {};
143}
144
145DenseI32ArrayAttr spirv::lookupLocalWorkGroupSize(Operation *op) {
146 if (auto entryPoint = spirv::lookupEntryPointABI(op))
147 return entryPoint.getWorkgroupSize();
148
149 return {};
150}
151
152spirv::ResourceLimitsAttr
153spirv::getDefaultResourceLimits(MLIRContext *context) {
154 // All the fields have default values. Here we just provide a nicer way to
155 // construct a default resource limit attribute.
156 Builder b(context);
157 return spirv::ResourceLimitsAttr::get(
158 context,
159 /*max_compute_shared_memory_size=*/16384,
160 /*max_compute_workgroup_invocations=*/128,
161 /*max_compute_workgroup_size=*/b.getI32ArrayAttr({128, 128, 64}),
162 /*subgroup_size=*/32,
163 /*min_subgroup_size=*/std::nullopt,
164 /*max_subgroup_size=*/std::nullopt,
165 /*cooperative_matrix_properties_khr=*/ArrayAttr{},
166 /*cooperative_matrix_properties_nv=*/ArrayAttr{});
167}
168
169StringRef spirv::getTargetEnvAttrName() { return "spirv.target_env"; }
170
171spirv::TargetEnvAttr spirv::getDefaultTargetEnv(MLIRContext *context) {
172 auto triple = spirv::VerCapExtAttr::get(spirv::Version::V_1_0,
173 {spirv::Capability::Shader},
174 ArrayRef<Extension>(), context);
175 return spirv::TargetEnvAttr::get(
176 triple, spirv::getDefaultResourceLimits(context),
177 spirv::ClientAPI::Unknown, spirv::Vendor::Unknown,
178 spirv::DeviceType::Unknown, spirv::TargetEnvAttr::kUnknownDeviceID);
179}
180
181spirv::TargetEnvAttr spirv::lookupTargetEnv(Operation *op) {
182 while (op) {
183 op = SymbolTable::getNearestSymbolTable(from: op);
184 if (!op)
185 break;
186
187 if (auto attr = op->getAttrOfType<spirv::TargetEnvAttr>(
188 spirv::getTargetEnvAttrName()))
189 return attr;
190
191 op = op->getParentOp();
192 }
193
194 return {};
195}
196
197spirv::TargetEnvAttr spirv::lookupTargetEnvOrDefault(Operation *op) {
198 if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(op))
199 return attr;
200
201 return getDefaultTargetEnv(context: op->getContext());
202}
203
204spirv::AddressingModel
205spirv::getAddressingModel(spirv::TargetEnvAttr targetAttr,
206 bool use64bitAddress) {
207 for (spirv::Capability cap : targetAttr.getCapabilities()) {
208 if (cap == Capability::Kernel)
209 return use64bitAddress ? spirv::AddressingModel::Physical64
210 : spirv::AddressingModel::Physical32;
211 // TODO PhysicalStorageBuffer64 is hard-coded here, but some information
212 // should come from TargetEnvAttr to select between PhysicalStorageBuffer64
213 // and PhysicalStorageBuffer64EXT
214 if (cap == Capability::PhysicalStorageBufferAddresses)
215 return spirv::AddressingModel::PhysicalStorageBuffer64;
216 }
217 // Logical addressing doesn't need any capabilities so return it as default.
218 return spirv::AddressingModel::Logical;
219}
220
221FailureOr<spirv::ExecutionModel>
222spirv::getExecutionModel(spirv::TargetEnvAttr targetAttr) {
223 for (spirv::Capability cap : targetAttr.getCapabilities()) {
224 if (cap == spirv::Capability::Kernel)
225 return spirv::ExecutionModel::Kernel;
226 if (cap == spirv::Capability::Shader)
227 return spirv::ExecutionModel::GLCompute;
228 }
229 return failure();
230}
231
232FailureOr<spirv::MemoryModel>
233spirv::getMemoryModel(spirv::TargetEnvAttr targetAttr) {
234 for (spirv::Capability cap : targetAttr.getCapabilities()) {
235 if (cap == spirv::Capability::Kernel)
236 return spirv::MemoryModel::OpenCL;
237 if (cap == spirv::Capability::Shader)
238 return spirv::MemoryModel::GLSL450;
239 }
240 return failure();
241}
242

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp