| 1 | //===- TargetAndABI.cpp - SPIR-V target and ABI utilities -----------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" |
| 10 | #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" |
| 11 | #include "mlir/IR/Builders.h" |
| 12 | #include "mlir/IR/Operation.h" |
| 13 | #include "mlir/IR/SymbolTable.h" |
| 14 | #include "mlir/Interfaces/FunctionInterfaces.h" |
| 15 | #include <optional> |
| 16 | |
| 17 | using namespace mlir; |
| 18 | |
| 19 | //===----------------------------------------------------------------------===// |
| 20 | // TargetEnv |
| 21 | //===----------------------------------------------------------------------===// |
| 22 | |
| 23 | spirv::TargetEnv::TargetEnv(spirv::TargetEnvAttr targetAttr) |
| 24 | : targetAttr(targetAttr) { |
| 25 | givenExtensions.insert_range(R: targetAttr.getExtensions()); |
| 26 | |
| 27 | // Add extensions implied by the current version. |
| 28 | givenExtensions.insert_range( |
| 29 | R: spirv::getImpliedExtensions(version: targetAttr.getVersion())); |
| 30 | |
| 31 | for (spirv::Capability cap : targetAttr.getCapabilities()) { |
| 32 | givenCapabilities.insert(V: cap); |
| 33 | |
| 34 | // Add capabilities implied by the current capability. |
| 35 | givenCapabilities.insert_range(R: spirv::getRecursiveImpliedCapabilities(cap)); |
| 36 | } |
| 37 | } |
| 38 | |
| 39 | spirv::Version spirv::TargetEnv::getVersion() const { |
| 40 | return targetAttr.getVersion(); |
| 41 | } |
| 42 | |
| 43 | bool spirv::TargetEnv::allows(spirv::Capability capability) const { |
| 44 | return givenCapabilities.count(V: capability); |
| 45 | } |
| 46 | |
| 47 | std::optional<spirv::Capability> |
| 48 | spirv::TargetEnv::allows(ArrayRef<spirv::Capability> caps) const { |
| 49 | const auto *chosen = llvm::find_if(Range&: caps, P: [this](spirv::Capability cap) { |
| 50 | return givenCapabilities.count(V: cap); |
| 51 | }); |
| 52 | if (chosen != caps.end()) |
| 53 | return *chosen; |
| 54 | return std::nullopt; |
| 55 | } |
| 56 | |
| 57 | bool spirv::TargetEnv::allows(spirv::Extension extension) const { |
| 58 | return givenExtensions.count(V: extension); |
| 59 | } |
| 60 | |
| 61 | std::optional<spirv::Extension> |
| 62 | spirv::TargetEnv::allows(ArrayRef<spirv::Extension> exts) const { |
| 63 | const auto *chosen = llvm::find_if(Range&: exts, P: [this](spirv::Extension ext) { |
| 64 | return givenExtensions.count(V: ext); |
| 65 | }); |
| 66 | if (chosen != exts.end()) |
| 67 | return *chosen; |
| 68 | return std::nullopt; |
| 69 | } |
| 70 | |
| 71 | spirv::Vendor spirv::TargetEnv::getVendorID() const { |
| 72 | return targetAttr.getVendorID(); |
| 73 | } |
| 74 | |
| 75 | spirv::DeviceType spirv::TargetEnv::getDeviceType() const { |
| 76 | return targetAttr.getDeviceType(); |
| 77 | } |
| 78 | |
| 79 | uint32_t spirv::TargetEnv::getDeviceID() const { |
| 80 | return targetAttr.getDeviceID(); |
| 81 | } |
| 82 | |
| 83 | spirv::ResourceLimitsAttr spirv::TargetEnv::getResourceLimits() const { |
| 84 | return targetAttr.getResourceLimits(); |
| 85 | } |
| 86 | |
| 87 | MLIRContext *spirv::TargetEnv::getContext() const { |
| 88 | return targetAttr.getContext(); |
| 89 | } |
| 90 | |
| 91 | //===----------------------------------------------------------------------===// |
| 92 | // Utility functions |
| 93 | //===----------------------------------------------------------------------===// |
| 94 | |
| 95 | StringRef spirv::getInterfaceVarABIAttrName() { |
| 96 | return "spirv.interface_var_abi" ; |
| 97 | } |
| 98 | |
| 99 | spirv::InterfaceVarABIAttr |
| 100 | spirv::getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding, |
| 101 | std::optional<spirv::StorageClass> storageClass, |
| 102 | MLIRContext *context) { |
| 103 | return spirv::InterfaceVarABIAttr::get(descriptorSet, binding, storageClass, |
| 104 | context); |
| 105 | } |
| 106 | |
| 107 | bool spirv::needsInterfaceVarABIAttrs(spirv::TargetEnvAttr targetAttr) { |
| 108 | for (spirv::Capability cap : targetAttr.getCapabilities()) { |
| 109 | if (cap == spirv::Capability::Kernel) |
| 110 | return false; |
| 111 | if (cap == spirv::Capability::Shader) |
| 112 | return true; |
| 113 | } |
| 114 | return false; |
| 115 | } |
| 116 | |
| 117 | StringRef spirv::getEntryPointABIAttrName() { return "spirv.entry_point_abi" ; } |
| 118 | |
| 119 | spirv::EntryPointABIAttr spirv::getEntryPointABIAttr( |
| 120 | MLIRContext *context, ArrayRef<int32_t> workgroupSize, |
| 121 | std::optional<int> subgroupSize, std::optional<int> targetWidth) { |
| 122 | DenseI32ArrayAttr workgroupSizeAttr; |
| 123 | if (!workgroupSize.empty()) { |
| 124 | assert(workgroupSize.size() == 3); |
| 125 | workgroupSizeAttr = DenseI32ArrayAttr::get(context, content: workgroupSize); |
| 126 | } |
| 127 | return spirv::EntryPointABIAttr::get(context, workgroup_size: workgroupSizeAttr, subgroup_size: subgroupSize, |
| 128 | target_width: targetWidth); |
| 129 | } |
| 130 | |
| 131 | spirv::EntryPointABIAttr spirv::lookupEntryPointABI(Operation *op) { |
| 132 | while (op && !isa<FunctionOpInterface>(Val: op)) |
| 133 | op = op->getParentOp(); |
| 134 | if (!op) |
| 135 | return {}; |
| 136 | |
| 137 | if (auto attr = op->getAttrOfType<spirv::EntryPointABIAttr>( |
| 138 | name: spirv::getEntryPointABIAttrName())) |
| 139 | return attr; |
| 140 | |
| 141 | return {}; |
| 142 | } |
| 143 | |
| 144 | DenseI32ArrayAttr spirv::lookupLocalWorkGroupSize(Operation *op) { |
| 145 | if (auto entryPoint = spirv::lookupEntryPointABI(op)) |
| 146 | return entryPoint.getWorkgroupSize(); |
| 147 | |
| 148 | return {}; |
| 149 | } |
| 150 | |
| 151 | spirv::ResourceLimitsAttr |
| 152 | spirv::getDefaultResourceLimits(MLIRContext *context) { |
| 153 | // All the fields have default values. Here we just provide a nicer way to |
| 154 | // construct a default resource limit attribute. |
| 155 | Builder b(context); |
| 156 | return spirv::ResourceLimitsAttr::get( |
| 157 | context, |
| 158 | /*max_compute_shared_memory_size=*/16384, |
| 159 | /*max_compute_workgroup_invocations=*/128, |
| 160 | /*max_compute_workgroup_size=*/b.getI32ArrayAttr(values: {128, 128, 64}), |
| 161 | /*subgroup_size=*/32, |
| 162 | /*min_subgroup_size=*/std::nullopt, |
| 163 | /*max_subgroup_size=*/std::nullopt, |
| 164 | /*cooperative_matrix_properties_khr=*/ArrayAttr{}, |
| 165 | /*cooperative_matrix_properties_nv=*/ArrayAttr{}); |
| 166 | } |
| 167 | |
| 168 | StringRef spirv::getTargetEnvAttrName() { return "spirv.target_env" ; } |
| 169 | |
| 170 | spirv::TargetEnvAttr spirv::getDefaultTargetEnv(MLIRContext *context) { |
| 171 | auto triple = spirv::VerCapExtAttr::get(version: spirv::Version::V_1_0, |
| 172 | capabilities: {spirv::Capability::Shader}, |
| 173 | extensions: ArrayRef<Extension>(), context); |
| 174 | return spirv::TargetEnvAttr::get( |
| 175 | triple, limits: spirv::getDefaultResourceLimits(context), |
| 176 | clientAPI: spirv::ClientAPI::Unknown, vendorID: spirv::Vendor::Unknown, |
| 177 | deviceType: spirv::DeviceType::Unknown, deviceId: spirv::TargetEnvAttr::kUnknownDeviceID); |
| 178 | } |
| 179 | |
| 180 | spirv::TargetEnvAttr spirv::lookupTargetEnv(Operation *op) { |
| 181 | while (op) { |
| 182 | op = SymbolTable::getNearestSymbolTable(from: op); |
| 183 | if (!op) |
| 184 | break; |
| 185 | |
| 186 | if (auto attr = op->getAttrOfType<spirv::TargetEnvAttr>( |
| 187 | name: spirv::getTargetEnvAttrName())) |
| 188 | return attr; |
| 189 | |
| 190 | op = op->getParentOp(); |
| 191 | } |
| 192 | |
| 193 | return {}; |
| 194 | } |
| 195 | |
| 196 | spirv::TargetEnvAttr spirv::lookupTargetEnvOrDefault(Operation *op) { |
| 197 | if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(op)) |
| 198 | return attr; |
| 199 | |
| 200 | return getDefaultTargetEnv(context: op->getContext()); |
| 201 | } |
| 202 | |
| 203 | spirv::AddressingModel |
| 204 | spirv::getAddressingModel(spirv::TargetEnvAttr targetAttr, |
| 205 | bool use64bitAddress) { |
| 206 | for (spirv::Capability cap : targetAttr.getCapabilities()) { |
| 207 | if (cap == Capability::Kernel) |
| 208 | return use64bitAddress ? spirv::AddressingModel::Physical64 |
| 209 | : spirv::AddressingModel::Physical32; |
| 210 | // TODO PhysicalStorageBuffer64 is hard-coded here, but some information |
| 211 | // should come from TargetEnvAttr to select between PhysicalStorageBuffer64 |
| 212 | // and PhysicalStorageBuffer64EXT |
| 213 | if (cap == Capability::PhysicalStorageBufferAddresses) |
| 214 | return spirv::AddressingModel::PhysicalStorageBuffer64; |
| 215 | } |
| 216 | // Logical addressing doesn't need any capabilities so return it as default. |
| 217 | return spirv::AddressingModel::Logical; |
| 218 | } |
| 219 | |
| 220 | FailureOr<spirv::ExecutionModel> |
| 221 | spirv::getExecutionModel(spirv::TargetEnvAttr targetAttr) { |
| 222 | for (spirv::Capability cap : targetAttr.getCapabilities()) { |
| 223 | if (cap == spirv::Capability::Kernel) |
| 224 | return spirv::ExecutionModel::Kernel; |
| 225 | if (cap == spirv::Capability::Shader) |
| 226 | return spirv::ExecutionModel::GLCompute; |
| 227 | } |
| 228 | return failure(); |
| 229 | } |
| 230 | |
| 231 | FailureOr<spirv::MemoryModel> |
| 232 | spirv::getMemoryModel(spirv::TargetEnvAttr targetAttr) { |
| 233 | for (spirv::Capability cap : targetAttr.getCapabilities()) { |
| 234 | if (cap == spirv::Capability::Kernel) |
| 235 | return spirv::MemoryModel::OpenCL; |
| 236 | if (cap == spirv::Capability::Shader) |
| 237 | return spirv::MemoryModel::GLSL450; |
| 238 | } |
| 239 | return failure(); |
| 240 | } |
| 241 | |