1//===- DeduceVersionExtensionCapabilityPass.cpp ---------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass to deduce minimal version/extension/capability
10// requirements for a spirv::ModuleOp.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
15
16#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
18#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
19#include "mlir/IR/Builders.h"
20#include "mlir/IR/Visitors.h"
21#include "llvm/ADT/StringExtras.h"
22#include <optional>
23
24namespace mlir {
25namespace spirv {
26#define GEN_PASS_DEF_SPIRVUPDATEVCEPASS
27#include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
28} // namespace spirv
29} // namespace mlir
30
31using namespace mlir;
32
33namespace {
34/// Pass to deduce minimal version/extension/capability requirements for a
35/// spirv::ModuleOp.
36class UpdateVCEPass final
37 : public spirv::impl::SPIRVUpdateVCEPassBase<UpdateVCEPass> {
38 void runOnOperation() override;
39};
40} // namespace
41
42/// Checks that `candidates` extension requirements are possible to be satisfied
43/// with the given `targetEnv` and updates `deducedExtensions` if so. Emits
44/// errors attaching to the given `op` on failures.
45///
46/// `candidates` is a vector of vector for extension requirements following
47/// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
48/// convention.
49static LogicalResult checkAndUpdateExtensionRequirements(
50 Operation *op, const spirv::TargetEnv &targetEnv,
51 const spirv::SPIRVType::ExtensionArrayRefVector &candidates,
52 SetVector<spirv::Extension> &deducedExtensions) {
53 for (const auto &ors : candidates) {
54 if (std::optional<spirv::Extension> chosen = targetEnv.allows(ors)) {
55 deducedExtensions.insert(X: *chosen);
56 } else {
57 SmallVector<StringRef, 4> extStrings;
58 for (spirv::Extension ext : ors)
59 extStrings.push_back(Elt: spirv::stringifyExtension(ext));
60
61 return op->emitError(message: "'")
62 << op->getName() << "' requires at least one extension in ["
63 << llvm::join(R&: extStrings, Separator: ", ")
64 << "] but none allowed in target environment";
65 }
66 }
67 return success();
68}
69
70/// Checks that `candidates`capability requirements are possible to be satisfied
71/// with the given `targetEnv` and updates `deducedCapabilities` if so. Emits
72/// errors attaching to the given `op` on failures.
73///
74/// `candidates` is a vector of vector for capability requirements following
75/// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
76/// convention.
77static LogicalResult checkAndUpdateCapabilityRequirements(
78 Operation *op, const spirv::TargetEnv &targetEnv,
79 const spirv::SPIRVType::CapabilityArrayRefVector &candidates,
80 SetVector<spirv::Capability> &deducedCapabilities) {
81 for (const auto &ors : candidates) {
82 if (std::optional<spirv::Capability> chosen = targetEnv.allows(ors)) {
83 deducedCapabilities.insert(X: *chosen);
84 } else {
85 SmallVector<StringRef, 4> capStrings;
86 for (spirv::Capability cap : ors)
87 capStrings.push_back(Elt: spirv::stringifyCapability(cap));
88
89 return op->emitError(message: "'")
90 << op->getName() << "' requires at least one capability in ["
91 << llvm::join(R&: capStrings, Separator: ", ")
92 << "] but none allowed in target environment";
93 }
94 }
95 return success();
96}
97
98void UpdateVCEPass::runOnOperation() {
99 spirv::ModuleOp module = getOperation();
100
101 spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnv(op: module);
102 if (!targetAttr) {
103 module.emitError(message: "missing 'spirv.target_env' attribute");
104 return signalPassFailure();
105 }
106
107 spirv::TargetEnv targetEnv(targetAttr);
108 spirv::Version allowedVersion = targetAttr.getVersion();
109
110 spirv::Version deducedVersion = spirv::Version::V_1_0;
111 SetVector<spirv::Extension> deducedExtensions;
112 SetVector<spirv::Capability> deducedCapabilities;
113
114 // Walk each SPIR-V op to deduce the minimal version/extension/capability
115 // requirements.
116 WalkResult walkResult = module.walk(callback: [&](Operation *op) -> WalkResult {
117 // Op min version requirements
118 if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(Val: op)) {
119 std::optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
120 if (minVersion) {
121 deducedVersion = std::max(a: deducedVersion, b: *minVersion);
122 if (deducedVersion > allowedVersion) {
123 return op->emitError(message: "'")
124 << op->getName() << "' requires min version "
125 << spirv::stringifyVersion(deducedVersion)
126 << " but target environment allows up to "
127 << spirv::stringifyVersion(allowedVersion);
128 }
129 }
130 }
131
132 // Op extension requirements
133 if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(Val: op))
134 if (failed(Result: checkAndUpdateExtensionRequirements(
135 op, targetEnv, candidates: extensions.getExtensions(), deducedExtensions)))
136 return WalkResult::interrupt();
137
138 // Op capability requirements
139 if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(Val: op))
140 if (failed(Result: checkAndUpdateCapabilityRequirements(
141 op, targetEnv, candidates: capabilities.getCapabilities(),
142 deducedCapabilities)))
143 return WalkResult::interrupt();
144
145 SmallVector<Type, 4> valueTypes;
146 valueTypes.append(in_start: op->operand_type_begin(), in_end: op->operand_type_end());
147 valueTypes.append(in_start: op->result_type_begin(), in_end: op->result_type_end());
148
149 // Special treatment for global variables, whose type requirements are
150 // conveyed by type attributes.
151 if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(Val: op))
152 valueTypes.push_back(Elt: globalVar.getType());
153
154 // Requirements from values' types
155 SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
156 SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
157 for (Type valueType : valueTypes) {
158 typeExtensions.clear();
159 cast<spirv::SPIRVType>(Val&: valueType).getExtensions(extensions&: typeExtensions);
160 if (failed(Result: checkAndUpdateExtensionRequirements(
161 op, targetEnv, candidates: typeExtensions, deducedExtensions)))
162 return WalkResult::interrupt();
163
164 typeCapabilities.clear();
165 cast<spirv::SPIRVType>(Val&: valueType).getCapabilities(capabilities&: typeCapabilities);
166 if (failed(Result: checkAndUpdateCapabilityRequirements(
167 op, targetEnv, candidates: typeCapabilities, deducedCapabilities)))
168 return WalkResult::interrupt();
169 }
170
171 return WalkResult::advance();
172 });
173
174 if (walkResult.wasInterrupted())
175 return signalPassFailure();
176
177 // TODO: verify that the deduced version is consistent with
178 // SPIR-V ops' maximal version requirements.
179
180 auto triple = spirv::VerCapExtAttr::get(
181 version: deducedVersion, capabilities: deducedCapabilities.getArrayRef(),
182 extensions: deducedExtensions.getArrayRef(), context: &getContext());
183 module->setAttr(name: spirv::ModuleOp::getVCETripleAttrName(), value: triple);
184}
185

source code of mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp