| 1 | //===- TargetAndABI.cpp - SPIR-V target and ABI utilities -----------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" |
| 10 | #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" |
| 11 | #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" |
| 12 | #include "mlir/IR/Builders.h" |
| 13 | #include "mlir/IR/Operation.h" |
| 14 | #include "mlir/IR/SymbolTable.h" |
| 15 | #include "mlir/Interfaces/FunctionInterfaces.h" |
| 16 | #include <optional> |
| 17 | |
| 18 | using namespace mlir; |
| 19 | |
| 20 | //===----------------------------------------------------------------------===// |
| 21 | // TargetEnv |
| 22 | //===----------------------------------------------------------------------===// |
| 23 | |
| 24 | spirv::TargetEnv::TargetEnv(spirv::TargetEnvAttr targetAttr) |
| 25 | : targetAttr(targetAttr) { |
| 26 | givenExtensions.insert_range(R: targetAttr.getExtensions()); |
| 27 | |
| 28 | // Add extensions implied by the current version. |
| 29 | givenExtensions.insert_range( |
| 30 | R: spirv::getImpliedExtensions(version: targetAttr.getVersion())); |
| 31 | |
| 32 | for (spirv::Capability cap : targetAttr.getCapabilities()) { |
| 33 | givenCapabilities.insert(cap); |
| 34 | |
| 35 | // Add capabilities implied by the current capability. |
| 36 | givenCapabilities.insert_range(spirv::getRecursiveImpliedCapabilities(cap)); |
| 37 | } |
| 38 | } |
| 39 | |
| 40 | spirv::Version spirv::TargetEnv::getVersion() const { |
| 41 | return targetAttr.getVersion(); |
| 42 | } |
| 43 | |
| 44 | bool spirv::TargetEnv::allows(spirv::Capability capability) const { |
| 45 | return givenCapabilities.count(V: capability); |
| 46 | } |
| 47 | |
| 48 | std::optional<spirv::Capability> |
| 49 | spirv::TargetEnv::allows(ArrayRef<spirv::Capability> caps) const { |
| 50 | const auto *chosen = llvm::find_if(Range&: caps, P: [this](spirv::Capability cap) { |
| 51 | return givenCapabilities.count(V: cap); |
| 52 | }); |
| 53 | if (chosen != caps.end()) |
| 54 | return *chosen; |
| 55 | return std::nullopt; |
| 56 | } |
| 57 | |
| 58 | bool spirv::TargetEnv::allows(spirv::Extension extension) const { |
| 59 | return givenExtensions.count(V: extension); |
| 60 | } |
| 61 | |
| 62 | std::optional<spirv::Extension> |
| 63 | spirv::TargetEnv::allows(ArrayRef<spirv::Extension> exts) const { |
| 64 | const auto *chosen = llvm::find_if(Range&: exts, P: [this](spirv::Extension ext) { |
| 65 | return givenExtensions.count(V: ext); |
| 66 | }); |
| 67 | if (chosen != exts.end()) |
| 68 | return *chosen; |
| 69 | return std::nullopt; |
| 70 | } |
| 71 | |
| 72 | spirv::Vendor spirv::TargetEnv::getVendorID() const { |
| 73 | return targetAttr.getVendorID(); |
| 74 | } |
| 75 | |
| 76 | spirv::DeviceType spirv::TargetEnv::getDeviceType() const { |
| 77 | return targetAttr.getDeviceType(); |
| 78 | } |
| 79 | |
| 80 | uint32_t spirv::TargetEnv::getDeviceID() const { |
| 81 | return targetAttr.getDeviceID(); |
| 82 | } |
| 83 | |
| 84 | spirv::ResourceLimitsAttr spirv::TargetEnv::getResourceLimits() const { |
| 85 | return targetAttr.getResourceLimits(); |
| 86 | } |
| 87 | |
| 88 | MLIRContext *spirv::TargetEnv::getContext() const { |
| 89 | return targetAttr.getContext(); |
| 90 | } |
| 91 | |
| 92 | //===----------------------------------------------------------------------===// |
| 93 | // Utility functions |
| 94 | //===----------------------------------------------------------------------===// |
| 95 | |
| 96 | StringRef spirv::getInterfaceVarABIAttrName() { |
| 97 | return "spirv.interface_var_abi" ; |
| 98 | } |
| 99 | |
| 100 | spirv::InterfaceVarABIAttr |
| 101 | spirv::getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding, |
| 102 | std::optional<spirv::StorageClass> storageClass, |
| 103 | MLIRContext *context) { |
| 104 | return spirv::InterfaceVarABIAttr::get(descriptorSet, binding, storageClass, |
| 105 | context); |
| 106 | } |
| 107 | |
| 108 | bool spirv::needsInterfaceVarABIAttrs(spirv::TargetEnvAttr targetAttr) { |
| 109 | for (spirv::Capability cap : targetAttr.getCapabilities()) { |
| 110 | if (cap == spirv::Capability::Kernel) |
| 111 | return false; |
| 112 | if (cap == spirv::Capability::Shader) |
| 113 | return true; |
| 114 | } |
| 115 | return false; |
| 116 | } |
| 117 | |
| 118 | StringRef spirv::getEntryPointABIAttrName() { return "spirv.entry_point_abi" ; } |
| 119 | |
| 120 | spirv::EntryPointABIAttr spirv::getEntryPointABIAttr( |
| 121 | MLIRContext *context, ArrayRef<int32_t> workgroupSize, |
| 122 | std::optional<int> subgroupSize, std::optional<int> targetWidth) { |
| 123 | DenseI32ArrayAttr workgroupSizeAttr; |
| 124 | if (!workgroupSize.empty()) { |
| 125 | assert(workgroupSize.size() == 3); |
| 126 | workgroupSizeAttr = DenseI32ArrayAttr::get(context, workgroupSize); |
| 127 | } |
| 128 | return spirv::EntryPointABIAttr::get(context, workgroupSizeAttr, subgroupSize, |
| 129 | targetWidth); |
| 130 | } |
| 131 | |
| 132 | spirv::EntryPointABIAttr spirv::lookupEntryPointABI(Operation *op) { |
| 133 | while (op && !isa<FunctionOpInterface>(Val: op)) |
| 134 | op = op->getParentOp(); |
| 135 | if (!op) |
| 136 | return {}; |
| 137 | |
| 138 | if (auto attr = op->getAttrOfType<spirv::EntryPointABIAttr>( |
| 139 | spirv::getEntryPointABIAttrName())) |
| 140 | return attr; |
| 141 | |
| 142 | return {}; |
| 143 | } |
| 144 | |
| 145 | DenseI32ArrayAttr spirv::lookupLocalWorkGroupSize(Operation *op) { |
| 146 | if (auto entryPoint = spirv::lookupEntryPointABI(op)) |
| 147 | return entryPoint.getWorkgroupSize(); |
| 148 | |
| 149 | return {}; |
| 150 | } |
| 151 | |
| 152 | spirv::ResourceLimitsAttr |
| 153 | spirv::getDefaultResourceLimits(MLIRContext *context) { |
| 154 | // All the fields have default values. Here we just provide a nicer way to |
| 155 | // construct a default resource limit attribute. |
| 156 | Builder b(context); |
| 157 | return spirv::ResourceLimitsAttr::get( |
| 158 | context, |
| 159 | /*max_compute_shared_memory_size=*/16384, |
| 160 | /*max_compute_workgroup_invocations=*/128, |
| 161 | /*max_compute_workgroup_size=*/b.getI32ArrayAttr({128, 128, 64}), |
| 162 | /*subgroup_size=*/32, |
| 163 | /*min_subgroup_size=*/std::nullopt, |
| 164 | /*max_subgroup_size=*/std::nullopt, |
| 165 | /*cooperative_matrix_properties_khr=*/ArrayAttr{}, |
| 166 | /*cooperative_matrix_properties_nv=*/ArrayAttr{}); |
| 167 | } |
| 168 | |
| 169 | StringRef spirv::getTargetEnvAttrName() { return "spirv.target_env" ; } |
| 170 | |
| 171 | spirv::TargetEnvAttr spirv::getDefaultTargetEnv(MLIRContext *context) { |
| 172 | auto triple = spirv::VerCapExtAttr::get(spirv::Version::V_1_0, |
| 173 | {spirv::Capability::Shader}, |
| 174 | ArrayRef<Extension>(), context); |
| 175 | return spirv::TargetEnvAttr::get( |
| 176 | triple, spirv::getDefaultResourceLimits(context), |
| 177 | spirv::ClientAPI::Unknown, spirv::Vendor::Unknown, |
| 178 | spirv::DeviceType::Unknown, spirv::TargetEnvAttr::kUnknownDeviceID); |
| 179 | } |
| 180 | |
| 181 | spirv::TargetEnvAttr spirv::lookupTargetEnv(Operation *op) { |
| 182 | while (op) { |
| 183 | op = SymbolTable::getNearestSymbolTable(from: op); |
| 184 | if (!op) |
| 185 | break; |
| 186 | |
| 187 | if (auto attr = op->getAttrOfType<spirv::TargetEnvAttr>( |
| 188 | spirv::getTargetEnvAttrName())) |
| 189 | return attr; |
| 190 | |
| 191 | op = op->getParentOp(); |
| 192 | } |
| 193 | |
| 194 | return {}; |
| 195 | } |
| 196 | |
| 197 | spirv::TargetEnvAttr spirv::lookupTargetEnvOrDefault(Operation *op) { |
| 198 | if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(op)) |
| 199 | return attr; |
| 200 | |
| 201 | return getDefaultTargetEnv(context: op->getContext()); |
| 202 | } |
| 203 | |
| 204 | spirv::AddressingModel |
| 205 | spirv::getAddressingModel(spirv::TargetEnvAttr targetAttr, |
| 206 | bool use64bitAddress) { |
| 207 | for (spirv::Capability cap : targetAttr.getCapabilities()) { |
| 208 | if (cap == Capability::Kernel) |
| 209 | return use64bitAddress ? spirv::AddressingModel::Physical64 |
| 210 | : spirv::AddressingModel::Physical32; |
| 211 | // TODO PhysicalStorageBuffer64 is hard-coded here, but some information |
| 212 | // should come from TargetEnvAttr to select between PhysicalStorageBuffer64 |
| 213 | // and PhysicalStorageBuffer64EXT |
| 214 | if (cap == Capability::PhysicalStorageBufferAddresses) |
| 215 | return spirv::AddressingModel::PhysicalStorageBuffer64; |
| 216 | } |
| 217 | // Logical addressing doesn't need any capabilities so return it as default. |
| 218 | return spirv::AddressingModel::Logical; |
| 219 | } |
| 220 | |
| 221 | FailureOr<spirv::ExecutionModel> |
| 222 | spirv::getExecutionModel(spirv::TargetEnvAttr targetAttr) { |
| 223 | for (spirv::Capability cap : targetAttr.getCapabilities()) { |
| 224 | if (cap == spirv::Capability::Kernel) |
| 225 | return spirv::ExecutionModel::Kernel; |
| 226 | if (cap == spirv::Capability::Shader) |
| 227 | return spirv::ExecutionModel::GLCompute; |
| 228 | } |
| 229 | return failure(); |
| 230 | } |
| 231 | |
| 232 | FailureOr<spirv::MemoryModel> |
| 233 | spirv::getMemoryModel(spirv::TargetEnvAttr targetAttr) { |
| 234 | for (spirv::Capability cap : targetAttr.getCapabilities()) { |
| 235 | if (cap == spirv::Capability::Kernel) |
| 236 | return spirv::MemoryModel::OpenCL; |
| 237 | if (cap == spirv::Capability::Shader) |
| 238 | return spirv::MemoryModel::GLSL450; |
| 239 | } |
| 240 | return failure(); |
| 241 | } |
| 242 | |