1//===- TestAvailability.cpp - Pass to test SPIR-V op availability ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Func/IR/FuncOps.h"
10#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
11#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
12#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
13#include "mlir/Pass/Pass.h"
14
15using namespace mlir;
16
17//===----------------------------------------------------------------------===//
18// Printing op availability pass
19//===----------------------------------------------------------------------===//
20
21namespace {
22/// A pass for testing SPIR-V op availability.
23struct PrintOpAvailability
24 : public PassWrapper<PrintOpAvailability, OperationPass<func::FuncOp>> {
25 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PrintOpAvailability)
26
27 void runOnOperation() override;
28 StringRef getArgument() const final { return "test-spirv-op-availability"; }
29 StringRef getDescription() const final {
30 return "Test SPIR-V op availability";
31 }
32};
33} // namespace
34
35void PrintOpAvailability::runOnOperation() {
36 auto f = getOperation();
37 llvm::outs() << f.getName() << "\n";
38
39 Dialect *spirvDialect = getContext().getLoadedDialect("spirv");
40
41 f->walk([&](Operation *op) {
42 if (op->getDialect() != spirvDialect)
43 return WalkResult::advance();
44
45 auto opName = op->getName();
46 auto &os = llvm::outs();
47
48 if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
49 std::optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
50 os << opName << " min version: ";
51 if (minVersion)
52 os << spirv::stringifyVersion(*minVersion) << "\n";
53 else
54 os << "None\n";
55 }
56
57 if (auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) {
58 std::optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion();
59 os << opName << " max version: ";
60 if (maxVersion)
61 os << spirv::stringifyVersion(*maxVersion) << "\n";
62 else
63 os << "None\n";
64 }
65
66 if (auto extension = dyn_cast<spirv::QueryExtensionInterface>(op)) {
67 os << opName << " extensions: [";
68 for (const auto &exts : extension.getExtensions()) {
69 os << " [";
70 llvm::interleaveComma(exts, os, [&](spirv::Extension ext) {
71 os << spirv::stringifyExtension(ext);
72 });
73 os << "]";
74 }
75 os << " ]\n";
76 }
77
78 if (auto capability = dyn_cast<spirv::QueryCapabilityInterface>(op)) {
79 os << opName << " capabilities: [";
80 for (const auto &caps : capability.getCapabilities()) {
81 os << " [";
82 llvm::interleaveComma(caps, os, [&](spirv::Capability cap) {
83 os << spirv::stringifyCapability(cap);
84 });
85 os << "]";
86 }
87 os << " ]\n";
88 }
89 os.flush();
90
91 return WalkResult::advance();
92 });
93}
94
95namespace mlir {
96void registerPrintSpirvAvailabilityPass() {
97 PassRegistration<PrintOpAvailability>();
98}
99} // namespace mlir
100
101//===----------------------------------------------------------------------===//
102// Converting target environment pass
103//===----------------------------------------------------------------------===//
104
105namespace {
106/// A pass for testing SPIR-V op availability.
107struct ConvertToTargetEnv
108 : public PassWrapper<ConvertToTargetEnv, OperationPass<func::FuncOp>> {
109 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvertToTargetEnv)
110
111 StringRef getArgument() const override { return "test-spirv-target-env"; }
112 StringRef getDescription() const override {
113 return "Test SPIR-V target environment";
114 }
115 void runOnOperation() override;
116};
117
118struct ConvertToAtomCmpExchangeWeak : RewritePattern {
119 ConvertToAtomCmpExchangeWeak(MLIRContext *context)
120 : RewritePattern("test.convert_to_atomic_compare_exchange_weak_op", 1,
121 context, {"spirv.AtomicCompareExchangeWeak"}) {}
122
123 LogicalResult matchAndRewrite(Operation *op,
124 PatternRewriter &rewriter) const override {
125 Value ptr = op->getOperand(idx: 0);
126 Value value = op->getOperand(idx: 1);
127 Value comparator = op->getOperand(idx: 2);
128
129 // Create a spirv.AtomicCompareExchangeWeak op with AtomicCounterMemory bits
130 // in memory semantics to additionally require AtomicStorage capability.
131 rewriter.replaceOpWithNewOp<spirv::AtomicCompareExchangeWeakOp>(
132 op, value.getType(), ptr, spirv::Scope::Workgroup,
133 spirv::MemorySemantics::AcquireRelease |
134 spirv::MemorySemantics::AtomicCounterMemory,
135 spirv::MemorySemantics::Acquire, value, comparator);
136 return success();
137 }
138};
139
140struct ConvertToBitReverse : RewritePattern {
141 ConvertToBitReverse(MLIRContext *context)
142 : RewritePattern("test.convert_to_bit_reverse_op", 1, context,
143 {"spirv.BitReverse"}) {}
144
145 LogicalResult matchAndRewrite(Operation *op,
146 PatternRewriter &rewriter) const override {
147 Value predicate = op->getOperand(idx: 0);
148 rewriter.replaceOpWithNewOp<spirv::BitReverseOp>(
149 op, op->getResult(0).getType(), predicate);
150 return success();
151 }
152};
153
154struct ConvertToGroupNonUniformBallot : RewritePattern {
155 ConvertToGroupNonUniformBallot(MLIRContext *context)
156 : RewritePattern("test.convert_to_group_non_uniform_ballot_op", 1,
157 context, {"spirv.GroupNonUniformBallot"}) {}
158
159 LogicalResult matchAndRewrite(Operation *op,
160 PatternRewriter &rewriter) const override {
161 Value predicate = op->getOperand(idx: 0);
162 rewriter.replaceOpWithNewOp<spirv::GroupNonUniformBallotOp>(
163 op, op->getResult(0).getType(), spirv::Scope::Workgroup, predicate);
164 return success();
165 }
166};
167
168struct ConvertToModule : RewritePattern {
169 ConvertToModule(MLIRContext *context)
170 : RewritePattern("test.convert_to_module_op", 1, context,
171 {"spirv.module"}) {}
172
173 LogicalResult matchAndRewrite(Operation *op,
174 PatternRewriter &rewriter) const override {
175 rewriter.replaceOpWithNewOp<spirv::ModuleOp>(
176 op, spirv::AddressingModel::PhysicalStorageBuffer64,
177 spirv::MemoryModel::Vulkan);
178 return success();
179 }
180};
181
182struct ConvertToSubgroupBallot : RewritePattern {
183 ConvertToSubgroupBallot(MLIRContext *context)
184 : RewritePattern("test.convert_to_subgroup_ballot_op", 1, context,
185 {"spirv.KHR.SubgroupBallot"}) {}
186
187 LogicalResult matchAndRewrite(Operation *op,
188 PatternRewriter &rewriter) const override {
189 Value predicate = op->getOperand(idx: 0);
190 rewriter.replaceOpWithNewOp<spirv::KHRSubgroupBallotOp>(
191 op, op->getResult(0).getType(), predicate);
192 return success();
193 }
194};
195
196template <const char *TestOpName, typename SPIRVOp>
197struct ConvertToIntegerDotProd : RewritePattern {
198 ConvertToIntegerDotProd(MLIRContext *context)
199 : RewritePattern(TestOpName, 1, context, {SPIRVOp::getOperationName()}) {}
200
201 LogicalResult matchAndRewrite(Operation *op,
202 PatternRewriter &rewriter) const override {
203 rewriter.replaceOpWithNewOp<SPIRVOp>(op, op->getResultTypes(),
204 op->getOperands(), op->getAttrs());
205 return success();
206 }
207};
208} // namespace
209
210void ConvertToTargetEnv::runOnOperation() {
211 MLIRContext *context = &getContext();
212 func::FuncOp fn = getOperation();
213
214 auto targetEnv = dyn_cast_or_null<spirv::TargetEnvAttr>(
215 fn.getOperation()->getDiscardableAttr(spirv::getTargetEnvAttrName()));
216 if (!targetEnv) {
217 fn.emitError("missing 'spirv.target_env' attribute");
218 return signalPassFailure();
219 }
220
221 auto target = SPIRVConversionTarget::get(targetAttr: targetEnv);
222
223 static constexpr char sDotTestOpName[] = "test.convert_to_sdot_op";
224 static constexpr char suDotTestOpName[] = "test.convert_to_sudot_op";
225 static constexpr char uDotTestOpName[] = "test.convert_to_udot_op";
226 static constexpr char sDotAccSatTestOpName[] =
227 "test.convert_to_sdot_acc_sat_op";
228 static constexpr char suDotAccSatTestOpName[] =
229 "test.convert_to_sudot_acc_sat_op";
230 static constexpr char uDotAccSatTestOpName[] =
231 "test.convert_to_udot_acc_sat_op";
232
233 RewritePatternSet patterns(context);
234 patterns.add<
235 ConvertToAtomCmpExchangeWeak, ConvertToBitReverse,
236 ConvertToGroupNonUniformBallot, ConvertToModule, ConvertToSubgroupBallot,
237 ConvertToIntegerDotProd<sDotTestOpName, spirv::SDotOp>,
238 ConvertToIntegerDotProd<suDotTestOpName, spirv::SUDotOp>,
239 ConvertToIntegerDotProd<uDotTestOpName, spirv::UDotOp>,
240 ConvertToIntegerDotProd<sDotAccSatTestOpName, spirv::SDotAccSatOp>,
241 ConvertToIntegerDotProd<suDotAccSatTestOpName, spirv::SUDotAccSatOp>,
242 ConvertToIntegerDotProd<uDotAccSatTestOpName, spirv::UDotAccSatOp>>(
243 context);
244
245 if (failed(applyPartialConversion(fn, *target, std::move(patterns))))
246 return signalPassFailure();
247}
248
249namespace mlir {
250void registerConvertToTargetEnvPass() {
251 PassRegistration<ConvertToTargetEnv>();
252}
253} // namespace mlir
254

source code of mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp