1 | //===- DeduceVersionExtensionCapabilityPass.cpp ---------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements a pass to deduce minimal version/extension/capability |
10 | // requirements for a spirv::ModuleOp. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Dialect/SPIRV/Transforms/Passes.h" |
15 | |
16 | #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" |
17 | #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
18 | #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" |
19 | #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" |
20 | #include "mlir/IR/Builders.h" |
21 | #include "mlir/IR/Visitors.h" |
22 | #include "llvm/ADT/SetVector.h" |
23 | #include "llvm/ADT/SmallSet.h" |
24 | #include "llvm/ADT/StringExtras.h" |
25 | #include <optional> |
26 | |
27 | namespace mlir { |
28 | namespace spirv { |
29 | #define GEN_PASS_DEF_SPIRVUPDATEVCEPASS |
30 | #include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc" |
31 | } // namespace spirv |
32 | } // namespace mlir |
33 | |
34 | using namespace mlir; |
35 | |
36 | namespace { |
37 | /// Pass to deduce minimal version/extension/capability requirements for a |
38 | /// spirv::ModuleOp. |
39 | class UpdateVCEPass final |
40 | : public spirv::impl::SPIRVUpdateVCEPassBase<UpdateVCEPass> { |
41 | void runOnOperation() override; |
42 | }; |
43 | } // namespace |
44 | |
45 | /// Checks that `candidates` extension requirements are possible to be satisfied |
46 | /// with the given `targetEnv` and updates `deducedExtensions` if so. Emits |
47 | /// errors attaching to the given `op` on failures. |
48 | /// |
49 | /// `candidates` is a vector of vector for extension requirements following |
50 | /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D)) |
51 | /// convention. |
52 | static LogicalResult checkAndUpdateExtensionRequirements( |
53 | Operation *op, const spirv::TargetEnv &targetEnv, |
54 | const spirv::SPIRVType::ExtensionArrayRefVector &candidates, |
55 | SetVector<spirv::Extension> &deducedExtensions) { |
56 | for (const auto &ors : candidates) { |
57 | if (std::optional<spirv::Extension> chosen = targetEnv.allows(ors)) { |
58 | deducedExtensions.insert(X: *chosen); |
59 | } else { |
60 | SmallVector<StringRef, 4> extStrings; |
61 | for (spirv::Extension ext : ors) |
62 | extStrings.push_back(spirv::stringifyExtension(ext)); |
63 | |
64 | return op->emitError(message: "'" ) |
65 | << op->getName() << "' requires at least one extension in [" |
66 | << llvm::join(R&: extStrings, Separator: ", " ) |
67 | << "] but none allowed in target environment" ; |
68 | } |
69 | } |
70 | return success(); |
71 | } |
72 | |
73 | /// Checks that `candidates`capability requirements are possible to be satisfied |
74 | /// with the given `targetEnv` and updates `deducedCapabilities` if so. Emits |
75 | /// errors attaching to the given `op` on failures. |
76 | /// |
77 | /// `candidates` is a vector of vector for capability requirements following |
78 | /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D)) |
79 | /// convention. |
80 | static LogicalResult checkAndUpdateCapabilityRequirements( |
81 | Operation *op, const spirv::TargetEnv &targetEnv, |
82 | const spirv::SPIRVType::CapabilityArrayRefVector &candidates, |
83 | SetVector<spirv::Capability> &deducedCapabilities) { |
84 | for (const auto &ors : candidates) { |
85 | if (std::optional<spirv::Capability> chosen = targetEnv.allows(ors)) { |
86 | deducedCapabilities.insert(X: *chosen); |
87 | } else { |
88 | SmallVector<StringRef, 4> capStrings; |
89 | for (spirv::Capability cap : ors) |
90 | capStrings.push_back(spirv::stringifyCapability(cap)); |
91 | |
92 | return op->emitError(message: "'" ) |
93 | << op->getName() << "' requires at least one capability in [" |
94 | << llvm::join(R&: capStrings, Separator: ", " ) |
95 | << "] but none allowed in target environment" ; |
96 | } |
97 | } |
98 | return success(); |
99 | } |
100 | |
101 | void UpdateVCEPass::runOnOperation() { |
102 | spirv::ModuleOp module = getOperation(); |
103 | |
104 | spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnv(op: module); |
105 | if (!targetAttr) { |
106 | module.emitError("missing 'spirv.target_env' attribute" ); |
107 | return signalPassFailure(); |
108 | } |
109 | |
110 | spirv::TargetEnv targetEnv(targetAttr); |
111 | spirv::Version allowedVersion = targetAttr.getVersion(); |
112 | |
113 | spirv::Version deducedVersion = spirv::Version::V_1_0; |
114 | SetVector<spirv::Extension> deducedExtensions; |
115 | SetVector<spirv::Capability> deducedCapabilities; |
116 | |
117 | // Walk each SPIR-V op to deduce the minimal version/extension/capability |
118 | // requirements. |
119 | WalkResult walkResult = module.walk([&](Operation *op) -> WalkResult { |
120 | // Op min version requirements |
121 | if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) { |
122 | std::optional<spirv::Version> minVersion = minVersionIfx.getMinVersion(); |
123 | if (minVersion) { |
124 | deducedVersion = std::max(a: deducedVersion, b: *minVersion); |
125 | if (deducedVersion > allowedVersion) { |
126 | return op->emitError("'" ) |
127 | << op->getName() << "' requires min version " |
128 | << spirv::stringifyVersion(deducedVersion) |
129 | << " but target environment allows up to " |
130 | << spirv::stringifyVersion(allowedVersion); |
131 | } |
132 | } |
133 | } |
134 | |
135 | // Op extension requirements |
136 | if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op)) |
137 | if (failed(checkAndUpdateExtensionRequirements( |
138 | op, targetEnv, extensions.getExtensions(), deducedExtensions))) |
139 | return WalkResult::interrupt(); |
140 | |
141 | // Op capability requirements |
142 | if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op)) |
143 | if (failed(checkAndUpdateCapabilityRequirements( |
144 | op, targetEnv, capabilities.getCapabilities(), |
145 | deducedCapabilities))) |
146 | return WalkResult::interrupt(); |
147 | |
148 | SmallVector<Type, 4> valueTypes; |
149 | valueTypes.append(in_start: op->operand_type_begin(), in_end: op->operand_type_end()); |
150 | valueTypes.append(in_start: op->result_type_begin(), in_end: op->result_type_end()); |
151 | |
152 | // Special treatment for global variables, whose type requirements are |
153 | // conveyed by type attributes. |
154 | if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op)) |
155 | valueTypes.push_back(Elt: globalVar.getType()); |
156 | |
157 | // Requirements from values' types |
158 | SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions; |
159 | SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities; |
160 | for (Type valueType : valueTypes) { |
161 | typeExtensions.clear(); |
162 | cast<spirv::SPIRVType>(Val&: valueType).getExtensions(typeExtensions); |
163 | if (failed(result: checkAndUpdateExtensionRequirements( |
164 | op, targetEnv, candidates: typeExtensions, deducedExtensions))) |
165 | return WalkResult::interrupt(); |
166 | |
167 | typeCapabilities.clear(); |
168 | cast<spirv::SPIRVType>(Val&: valueType).getCapabilities(typeCapabilities); |
169 | if (failed(result: checkAndUpdateCapabilityRequirements( |
170 | op, targetEnv, candidates: typeCapabilities, deducedCapabilities))) |
171 | return WalkResult::interrupt(); |
172 | } |
173 | |
174 | return WalkResult::advance(); |
175 | }); |
176 | |
177 | if (walkResult.wasInterrupted()) |
178 | return signalPassFailure(); |
179 | |
180 | // TODO: verify that the deduced version is consistent with |
181 | // SPIR-V ops' maximal version requirements. |
182 | |
183 | auto triple = spirv::VerCapExtAttr::get( |
184 | version: deducedVersion, capabilities: deducedCapabilities.getArrayRef(), |
185 | extensions: deducedExtensions.getArrayRef(), context: &getContext()); |
186 | module->setAttr(spirv::ModuleOp::getVCETripleAttrName(), triple); |
187 | } |
188 | |