1//===- SerializationTest.cpp - SPIR-V Serialization Tests -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains corner case tests for the SPIR-V serializer that are not
10// covered by normal serialization and deserialization roundtripping.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Target/SPIRV/Serialization.h"
15#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
16#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
18#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
19#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
20#include "mlir/IR/Builders.h"
21#include "mlir/IR/Location.h"
22#include "mlir/IR/MLIRContext.h"
23#include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
24#include "llvm/ADT/DenseSet.h"
25#include "llvm/ADT/STLExtras.h"
26#include "llvm/ADT/Sequence.h"
27#include "llvm/ADT/SmallVector.h"
28#include "llvm/ADT/StringRef.h"
29#include "gmock/gmock.h"
30
31using namespace mlir;
32
33//===----------------------------------------------------------------------===//
34// Test Fixture
35//===----------------------------------------------------------------------===//
36
37class SerializationTest : public ::testing::Test {
38protected:
39 SerializationTest() {
40 context.getOrLoadDialect<mlir::spirv::SPIRVDialect>();
41 initModuleOp();
42 }
43
44 /// Initializes an empty SPIR-V module op.
45 void initModuleOp() {
46 OpBuilder builder(&context);
47 OperationState state(UnknownLoc::get(&context),
48 spirv::ModuleOp::getOperationName());
49 state.addAttribute("addressing_model",
50 builder.getAttr<spirv::AddressingModelAttr>(
51 spirv::AddressingModel::Logical));
52 state.addAttribute("memory_model", builder.getAttr<spirv::MemoryModelAttr>(
53 spirv::MemoryModel::GLSL450));
54 state.addAttribute("vce_triple",
55 spirv::VerCapExtAttr::get(
56 spirv::Version::V_1_0, ArrayRef<spirv::Capability>(),
57 ArrayRef<spirv::Extension>(), &context));
58 spirv::ModuleOp::build(builder, state);
59 module = cast<spirv::ModuleOp>(Operation::create(state));
60 }
61
62 /// Gets the `struct { float }` type.
63 spirv::StructType getFloatStructType() {
64 OpBuilder builder(module->getRegion());
65 llvm::SmallVector<Type, 1> elementTypes{builder.getF32Type()};
66 llvm::SmallVector<spirv::StructType::OffsetInfo, 1> offsetInfo{0};
67 return spirv::StructType::get(memberTypes: elementTypes, offsetInfo);
68 }
69
70 /// Inserts a global variable of the given `type` and `name`.
71 spirv::GlobalVariableOp addGlobalVar(Type type, llvm::StringRef name) {
72 OpBuilder builder(module->getRegion());
73 auto ptrType = spirv::PointerType::get(type, spirv::StorageClass::Uniform);
74 return builder.create<spirv::GlobalVariableOp>(
75 UnknownLoc::get(&context), TypeAttr::get(ptrType),
76 builder.getStringAttr(name), nullptr);
77 }
78
79 // Inserts an Integer or a Vector of Integers constant of value 'val'.
80 spirv::ConstantOp addConstInt(Type type, const APInt &val) {
81 OpBuilder builder(module->getRegion());
82 auto loc = UnknownLoc::get(&context);
83
84 if (auto intType = dyn_cast<IntegerType>(type)) {
85 return builder.create<spirv::ConstantOp>(
86 loc, type, builder.getIntegerAttr(type, val));
87 }
88 if (auto vectorType = dyn_cast<VectorType>(type)) {
89 Type elemType = vectorType.getElementType();
90 if (auto intType = dyn_cast<IntegerType>(elemType)) {
91 return builder.create<spirv::ConstantOp>(
92 loc, type,
93 DenseElementsAttr::get(vectorType,
94 IntegerAttr::get(elemType, val).getValue()));
95 }
96 }
97 llvm_unreachable("unimplemented types for AddConstInt()");
98 }
99
100 /// Handles a SPIR-V instruction with the given `opcode` and `operand`.
101 /// Returns true to interrupt.
102 using HandleFn = llvm::function_ref<bool(spirv::Opcode opcode,
103 ArrayRef<uint32_t> operands)>;
104
105 /// Returns true if we can find a matching instruction in the SPIR-V blob.
106 bool scanInstruction(HandleFn handleFn) {
107 auto binarySize = binary.size();
108 auto *begin = binary.begin();
109 auto currOffset = spirv::kHeaderWordCount;
110
111 while (currOffset < binarySize) {
112 auto wordCount = binary[currOffset] >> 16;
113 if (!wordCount || (currOffset + wordCount > binarySize))
114 return false;
115
116 spirv::Opcode opcode =
117 static_cast<spirv::Opcode>(binary[currOffset] & 0xffff);
118 llvm::ArrayRef<uint32_t> operands(begin + currOffset + 1,
119 begin + currOffset + wordCount);
120 if (handleFn(opcode, operands))
121 return true;
122
123 currOffset += wordCount;
124 }
125 return false;
126 }
127
128protected:
129 MLIRContext context;
130 OwningOpRef<spirv::ModuleOp> module;
131 SmallVector<uint32_t, 0> binary;
132};
133
134//===----------------------------------------------------------------------===//
135// Block decoration
136//===----------------------------------------------------------------------===//
137
138TEST_F(SerializationTest, ContainsBlockDecoration) {
139 auto structType = getFloatStructType();
140 addGlobalVar(structType, "var0");
141
142 ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary)));
143
144 auto hasBlockDecoration = [](spirv::Opcode opcode,
145 ArrayRef<uint32_t> operands) {
146 return opcode == spirv::Opcode::OpDecorate && operands.size() == 2 &&
147 operands[1] == static_cast<uint32_t>(spirv::Decoration::Block);
148 };
149 EXPECT_TRUE(scanInstruction(hasBlockDecoration));
150}
151
152TEST_F(SerializationTest, ContainsNoDuplicatedBlockDecoration) {
153 auto structType = getFloatStructType();
154 // Two global variables using the same type should not decorate the type with
155 // duplicated `Block` decorations.
156 addGlobalVar(structType, "var0");
157 addGlobalVar(structType, "var1");
158
159 ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary)));
160
161 unsigned count = 0;
162 auto countBlockDecoration = [&count](spirv::Opcode opcode,
163 ArrayRef<uint32_t> operands) {
164 if (opcode == spirv::Opcode::OpDecorate && operands.size() == 2 &&
165 operands[1] == static_cast<uint32_t>(spirv::Decoration::Block))
166 ++count;
167 return false;
168 };
169 ASSERT_FALSE(scanInstruction(countBlockDecoration));
170 EXPECT_EQ(count, 1u);
171}
172
173TEST_F(SerializationTest, SignlessVsSignedIntegerConstantBitExtension) {
174
175 auto signlessInt16Type =
176 IntegerType::get(&context, 16, IntegerType::Signless);
177 auto signedInt16Type = IntegerType::get(&context, 16, IntegerType::Signed);
178 // Check the bit extension of same value under different signedness semantics.
179 APInt signlessIntConstVal(signlessInt16Type.getWidth(), -1,
180 signlessInt16Type.getSignedness());
181 APInt signedIntConstVal(signedInt16Type.getWidth(), -1,
182 signedInt16Type.getSignedness());
183
184 addConstInt(signlessInt16Type, signlessIntConstVal);
185 addConstInt(signedInt16Type, signedIntConstVal);
186 ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary)));
187
188 auto hasSignlessVal = [&](spirv::Opcode opcode, ArrayRef<uint32_t> operands) {
189 return opcode == spirv::Opcode::OpConstant && operands.size() == 3 &&
190 operands[2] == 65535;
191 };
192 EXPECT_TRUE(scanInstruction(hasSignlessVal));
193
194 auto hasSignedVal = [&](spirv::Opcode opcode, ArrayRef<uint32_t> operands) {
195 return opcode == spirv::Opcode::OpConstant && operands.size() == 3 &&
196 operands[2] == 4294967295;
197 };
198 EXPECT_TRUE(scanInstruction(hasSignedVal));
199}
200
201TEST_F(SerializationTest, ContainsSymbolName) {
202 auto structType = getFloatStructType();
203 addGlobalVar(structType, "var0");
204
205 spirv::SerializationOptions options;
206 options.emitSymbolName = true;
207 ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary, options)));
208
209 auto hasVarName = [](spirv::Opcode opcode, ArrayRef<uint32_t> operands) {
210 unsigned index = 1; // Skip the result <id>
211 return opcode == spirv::Opcode::OpName &&
212 spirv::decodeStringLiteral(operands, index) == "var0";
213 };
214 EXPECT_TRUE(scanInstruction(hasVarName));
215}
216
217TEST_F(SerializationTest, DoesNotContainSymbolName) {
218 auto structType = getFloatStructType();
219 addGlobalVar(structType, "var0");
220
221 spirv::SerializationOptions options;
222 options.emitSymbolName = false;
223 ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary, options)));
224
225 auto hasVarName = [](spirv::Opcode opcode, ArrayRef<uint32_t> operands) {
226 unsigned index = 1; // Skip the result <id>
227 return opcode == spirv::Opcode::OpName &&
228 spirv::decodeStringLiteral(operands, index) == "var0";
229 };
230 EXPECT_FALSE(scanInstruction(hasVarName));
231}
232

source code of mlir/unittests/Dialect/SPIRV/SerializationTest.cpp