1//===- DeduceVersionExtensionCapabilityPass.cpp ---------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass to deduce minimal version/extension/capability
10// requirements for a spirv::ModuleOp.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
15
16#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
20#include "mlir/IR/Builders.h"
21#include "mlir/IR/Visitors.h"
22#include "llvm/ADT/SetVector.h"
23#include "llvm/ADT/SmallSet.h"
24#include "llvm/ADT/StringExtras.h"
25#include <optional>
26
27namespace mlir {
28namespace spirv {
29#define GEN_PASS_DEF_SPIRVUPDATEVCEPASS
30#include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
31} // namespace spirv
32} // namespace mlir
33
34using namespace mlir;
35
36namespace {
37/// Pass to deduce minimal version/extension/capability requirements for a
38/// spirv::ModuleOp.
39class UpdateVCEPass final
40 : public spirv::impl::SPIRVUpdateVCEPassBase<UpdateVCEPass> {
41 void runOnOperation() override;
42};
43} // namespace
44
45/// Checks that `candidates` extension requirements are possible to be satisfied
46/// with the given `targetEnv` and updates `deducedExtensions` if so. Emits
47/// errors attaching to the given `op` on failures.
48///
49/// `candidates` is a vector of vector for extension requirements following
50/// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
51/// convention.
52static LogicalResult checkAndUpdateExtensionRequirements(
53 Operation *op, const spirv::TargetEnv &targetEnv,
54 const spirv::SPIRVType::ExtensionArrayRefVector &candidates,
55 SetVector<spirv::Extension> &deducedExtensions) {
56 for (const auto &ors : candidates) {
57 if (std::optional<spirv::Extension> chosen = targetEnv.allows(ors)) {
58 deducedExtensions.insert(X: *chosen);
59 } else {
60 SmallVector<StringRef, 4> extStrings;
61 for (spirv::Extension ext : ors)
62 extStrings.push_back(spirv::stringifyExtension(ext));
63
64 return op->emitError(message: "'")
65 << op->getName() << "' requires at least one extension in ["
66 << llvm::join(R&: extStrings, Separator: ", ")
67 << "] but none allowed in target environment";
68 }
69 }
70 return success();
71}
72
73/// Checks that `candidates`capability requirements are possible to be satisfied
74/// with the given `targetEnv` and updates `deducedCapabilities` if so. Emits
75/// errors attaching to the given `op` on failures.
76///
77/// `candidates` is a vector of vector for capability requirements following
78/// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
79/// convention.
80static LogicalResult checkAndUpdateCapabilityRequirements(
81 Operation *op, const spirv::TargetEnv &targetEnv,
82 const spirv::SPIRVType::CapabilityArrayRefVector &candidates,
83 SetVector<spirv::Capability> &deducedCapabilities) {
84 for (const auto &ors : candidates) {
85 if (std::optional<spirv::Capability> chosen = targetEnv.allows(ors)) {
86 deducedCapabilities.insert(X: *chosen);
87 } else {
88 SmallVector<StringRef, 4> capStrings;
89 for (spirv::Capability cap : ors)
90 capStrings.push_back(spirv::stringifyCapability(cap));
91
92 return op->emitError(message: "'")
93 << op->getName() << "' requires at least one capability in ["
94 << llvm::join(R&: capStrings, Separator: ", ")
95 << "] but none allowed in target environment";
96 }
97 }
98 return success();
99}
100
101void UpdateVCEPass::runOnOperation() {
102 spirv::ModuleOp module = getOperation();
103
104 spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnv(op: module);
105 if (!targetAttr) {
106 module.emitError("missing 'spirv.target_env' attribute");
107 return signalPassFailure();
108 }
109
110 spirv::TargetEnv targetEnv(targetAttr);
111 spirv::Version allowedVersion = targetAttr.getVersion();
112
113 spirv::Version deducedVersion = spirv::Version::V_1_0;
114 SetVector<spirv::Extension> deducedExtensions;
115 SetVector<spirv::Capability> deducedCapabilities;
116
117 // Walk each SPIR-V op to deduce the minimal version/extension/capability
118 // requirements.
119 WalkResult walkResult = module.walk([&](Operation *op) -> WalkResult {
120 // Op min version requirements
121 if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
122 std::optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
123 if (minVersion) {
124 deducedVersion = std::max(a: deducedVersion, b: *minVersion);
125 if (deducedVersion > allowedVersion) {
126 return op->emitError("'")
127 << op->getName() << "' requires min version "
128 << spirv::stringifyVersion(deducedVersion)
129 << " but target environment allows up to "
130 << spirv::stringifyVersion(allowedVersion);
131 }
132 }
133 }
134
135 // Op extension requirements
136 if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
137 if (failed(checkAndUpdateExtensionRequirements(
138 op, targetEnv, extensions.getExtensions(), deducedExtensions)))
139 return WalkResult::interrupt();
140
141 // Op capability requirements
142 if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
143 if (failed(checkAndUpdateCapabilityRequirements(
144 op, targetEnv, capabilities.getCapabilities(),
145 deducedCapabilities)))
146 return WalkResult::interrupt();
147
148 SmallVector<Type, 4> valueTypes;
149 valueTypes.append(in_start: op->operand_type_begin(), in_end: op->operand_type_end());
150 valueTypes.append(in_start: op->result_type_begin(), in_end: op->result_type_end());
151
152 // Special treatment for global variables, whose type requirements are
153 // conveyed by type attributes.
154 if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
155 valueTypes.push_back(Elt: globalVar.getType());
156
157 // Requirements from values' types
158 SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
159 SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
160 for (Type valueType : valueTypes) {
161 typeExtensions.clear();
162 cast<spirv::SPIRVType>(Val&: valueType).getExtensions(typeExtensions);
163 if (failed(result: checkAndUpdateExtensionRequirements(
164 op, targetEnv, candidates: typeExtensions, deducedExtensions)))
165 return WalkResult::interrupt();
166
167 typeCapabilities.clear();
168 cast<spirv::SPIRVType>(Val&: valueType).getCapabilities(typeCapabilities);
169 if (failed(result: checkAndUpdateCapabilityRequirements(
170 op, targetEnv, candidates: typeCapabilities, deducedCapabilities)))
171 return WalkResult::interrupt();
172 }
173
174 return WalkResult::advance();
175 });
176
177 if (walkResult.wasInterrupted())
178 return signalPassFailure();
179
180 // TODO: verify that the deduced version is consistent with
181 // SPIR-V ops' maximal version requirements.
182
183 auto triple = spirv::VerCapExtAttr::get(
184 version: deducedVersion, capabilities: deducedCapabilities.getArrayRef(),
185 extensions: deducedExtensions.getArrayRef(), context: &getContext());
186 module->setAttr(spirv::ModuleOp::getVCETripleAttrName(), triple);
187}
188

source code of mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp