1 | //===- DeserializationTest.cpp - SPIR-V Deserialization Tests -------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // The purpose of this file is to provide negative deserialization tests. |
10 | // For positive deserialization tests, please use serialization and |
11 | // deserialization for roundtripping. |
12 | // |
13 | //===----------------------------------------------------------------------===// |
14 | |
15 | #include "mlir/Target/SPIRV/Deserialization.h" |
16 | #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" |
17 | #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
18 | #include "mlir/IR/Diagnostics.h" |
19 | #include "mlir/IR/MLIRContext.h" |
20 | #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" |
21 | #include "gmock/gmock.h" |
22 | |
23 | #include <memory> |
24 | |
25 | using namespace mlir; |
26 | |
27 | using ::testing::StrEq; |
28 | |
29 | //===----------------------------------------------------------------------===// |
30 | // Test Fixture |
31 | //===----------------------------------------------------------------------===// |
32 | |
33 | /// A deserialization test fixture providing minimal SPIR-V building and |
34 | /// diagnostic checking utilities. |
35 | class DeserializationTest : public ::testing::Test { |
36 | protected: |
37 | DeserializationTest() { |
38 | context.getOrLoadDialect<mlir::spirv::SPIRVDialect>(); |
39 | // Register a diagnostic handler to capture the diagnostic so that we can |
40 | // check it later. |
41 | context.getDiagEngine().registerHandler(handler: [&](Diagnostic &diag) { |
42 | diagnostic = std::make_unique<Diagnostic>(args: std::move(diag)); |
43 | }); |
44 | } |
45 | |
46 | /// Performs deserialization and returns the constructed spirv.module op. |
47 | OwningOpRef<spirv::ModuleOp> deserialize() { |
48 | return spirv::deserialize(binary, context: &context); |
49 | } |
50 | |
51 | /// Checks there is a diagnostic generated with the given `errorMessage`. |
52 | void expectDiagnostic(StringRef errorMessage) { |
53 | ASSERT_NE(nullptr, diagnostic.get()); |
54 | |
55 | // TODO: check error location too. |
56 | EXPECT_THAT(diagnostic->str(), StrEq(std::string(errorMessage))); |
57 | } |
58 | |
59 | //===--------------------------------------------------------------------===// |
60 | // SPIR-V builder methods |
61 | //===--------------------------------------------------------------------===// |
62 | |
63 | /// Adds the SPIR-V module header to `binary`. |
64 | void () { |
65 | spirv::appendModuleHeader(binary, spirv::Version::V_1_0, /*idBound=*/0); |
66 | } |
67 | |
68 | /// Adds the SPIR-V instruction into `binary`. |
69 | void addInstruction(spirv::Opcode op, ArrayRef<uint32_t> operands) { |
70 | uint32_t wordCount = 1 + operands.size(); |
71 | binary.push_back(spirv::Elt: getPrefixedOpcode(wordCount, op)); |
72 | binary.append(in_start: operands.begin(), in_end: operands.end()); |
73 | } |
74 | |
75 | uint32_t addVoidType() { |
76 | auto id = nextID++; |
77 | addInstruction(spirv::Opcode::OpTypeVoid, {id}); |
78 | return id; |
79 | } |
80 | |
81 | uint32_t addIntType(uint32_t bitwidth) { |
82 | auto id = nextID++; |
83 | addInstruction(spirv::Opcode::OpTypeInt, {id, bitwidth, /*signedness=*/1}); |
84 | return id; |
85 | } |
86 | |
87 | uint32_t addStructType(ArrayRef<uint32_t> memberTypes) { |
88 | auto id = nextID++; |
89 | SmallVector<uint32_t, 2> words; |
90 | words.push_back(Elt: id); |
91 | words.append(in_start: memberTypes.begin(), in_end: memberTypes.end()); |
92 | addInstruction(spirv::Opcode::OpTypeStruct, words); |
93 | return id; |
94 | } |
95 | |
96 | uint32_t addFunctionType(uint32_t retType, ArrayRef<uint32_t> paramTypes) { |
97 | auto id = nextID++; |
98 | SmallVector<uint32_t, 4> operands; |
99 | operands.push_back(Elt: id); |
100 | operands.push_back(Elt: retType); |
101 | operands.append(in_start: paramTypes.begin(), in_end: paramTypes.end()); |
102 | addInstruction(spirv::Opcode::OpTypeFunction, operands); |
103 | return id; |
104 | } |
105 | |
106 | uint32_t addFunction(uint32_t retType, uint32_t fnType) { |
107 | auto id = nextID++; |
108 | addInstruction(spirv::Opcode::OpFunction, |
109 | {retType, id, |
110 | static_cast<uint32_t>(spirv::FunctionControl::None), |
111 | fnType}); |
112 | return id; |
113 | } |
114 | |
115 | void addFunctionEnd() { addInstruction(spirv::Opcode::OpFunctionEnd, {}); } |
116 | |
117 | void addReturn() { addInstruction(spirv::Opcode::OpReturn, {}); } |
118 | |
119 | protected: |
120 | SmallVector<uint32_t, 5> binary; |
121 | uint32_t nextID = 1; |
122 | MLIRContext context; |
123 | std::unique_ptr<Diagnostic> diagnostic; |
124 | }; |
125 | |
126 | //===----------------------------------------------------------------------===// |
127 | // Basics |
128 | //===----------------------------------------------------------------------===// |
129 | |
130 | TEST_F(DeserializationTest, EmptyModuleFailure) { |
131 | ASSERT_FALSE(deserialize()); |
132 | expectDiagnostic(errorMessage: "SPIR-V binary module must have a 5-word header" ); |
133 | } |
134 | |
135 | TEST_F(DeserializationTest, WrongMagicNumberFailure) { |
136 | addHeader(); |
137 | binary.front() = 0xdeadbeef; // Change to a wrong magic number |
138 | ASSERT_FALSE(deserialize()); |
139 | expectDiagnostic(errorMessage: "incorrect magic number" ); |
140 | } |
141 | |
142 | TEST_F(DeserializationTest, OnlyHeaderSuccess) { |
143 | addHeader(); |
144 | EXPECT_TRUE(deserialize()); |
145 | } |
146 | |
147 | TEST_F(DeserializationTest, ZeroWordCountFailure) { |
148 | addHeader(); |
149 | binary.push_back(Elt: 0); // OpNop with zero word count |
150 | |
151 | ASSERT_FALSE(deserialize()); |
152 | expectDiagnostic(errorMessage: "word count cannot be zero" ); |
153 | } |
154 | |
155 | TEST_F(DeserializationTest, InsufficientWordFailure) { |
156 | addHeader(); |
157 | binary.push_back((2u << 16) | |
158 | static_cast<uint32_t>(spirv::Opcode::OpTypeVoid)); |
159 | // Missing word for type <id>. |
160 | |
161 | ASSERT_FALSE(deserialize()); |
162 | expectDiagnostic(errorMessage: "insufficient words for the last instruction" ); |
163 | } |
164 | |
165 | //===----------------------------------------------------------------------===// |
166 | // Types |
167 | //===----------------------------------------------------------------------===// |
168 | |
169 | TEST_F(DeserializationTest, IntTypeMissingSignednessFailure) { |
170 | addHeader(); |
171 | addInstruction(spirv::Opcode::OpTypeInt, {nextID++, 32}); |
172 | |
173 | ASSERT_FALSE(deserialize()); |
174 | expectDiagnostic(errorMessage: "OpTypeInt must have bitwidth and signedness parameters" ); |
175 | } |
176 | |
177 | //===----------------------------------------------------------------------===// |
178 | // StructType |
179 | //===----------------------------------------------------------------------===// |
180 | |
181 | TEST_F(DeserializationTest, OpMemberNameSuccess) { |
182 | addHeader(); |
183 | SmallVector<uint32_t, 5> typeDecl; |
184 | std::swap(LHS&: typeDecl, RHS&: binary); |
185 | |
186 | auto int32Type = addIntType(bitwidth: 32); |
187 | auto structType = addStructType(memberTypes: {int32Type, int32Type}); |
188 | std::swap(LHS&: typeDecl, RHS&: binary); |
189 | |
190 | SmallVector<uint32_t, 5> operands1 = {structType, 0}; |
191 | (void)spirv::encodeStringLiteralInto(binary&: operands1, literal: "i1" ); |
192 | addInstruction(spirv::Opcode::OpMemberName, operands1); |
193 | |
194 | SmallVector<uint32_t, 5> operands2 = {structType, 1}; |
195 | (void)spirv::encodeStringLiteralInto(binary&: operands2, literal: "i2" ); |
196 | addInstruction(spirv::Opcode::OpMemberName, operands2); |
197 | |
198 | binary.append(in_start: typeDecl.begin(), in_end: typeDecl.end()); |
199 | EXPECT_TRUE(deserialize()); |
200 | } |
201 | |
202 | TEST_F(DeserializationTest, OpMemberNameMissingOperands) { |
203 | addHeader(); |
204 | SmallVector<uint32_t, 5> typeDecl; |
205 | std::swap(LHS&: typeDecl, RHS&: binary); |
206 | |
207 | auto int32Type = addIntType(bitwidth: 32); |
208 | auto int64Type = addIntType(bitwidth: 64); |
209 | auto structType = addStructType(memberTypes: {int32Type, int64Type}); |
210 | std::swap(LHS&: typeDecl, RHS&: binary); |
211 | |
212 | SmallVector<uint32_t, 5> operands1 = {structType}; |
213 | addInstruction(spirv::Opcode::OpMemberName, operands1); |
214 | |
215 | binary.append(in_start: typeDecl.begin(), in_end: typeDecl.end()); |
216 | ASSERT_FALSE(deserialize()); |
217 | expectDiagnostic(errorMessage: "OpMemberName must have at least 3 operands" ); |
218 | } |
219 | |
220 | TEST_F(DeserializationTest, OpMemberNameExcessOperands) { |
221 | addHeader(); |
222 | SmallVector<uint32_t, 5> typeDecl; |
223 | std::swap(LHS&: typeDecl, RHS&: binary); |
224 | |
225 | auto int32Type = addIntType(bitwidth: 32); |
226 | auto structType = addStructType(memberTypes: {int32Type}); |
227 | std::swap(LHS&: typeDecl, RHS&: binary); |
228 | |
229 | SmallVector<uint32_t, 5> operands = {structType, 0}; |
230 | (void)spirv::encodeStringLiteralInto(binary&: operands, literal: "int32" ); |
231 | operands.push_back(Elt: 42); |
232 | addInstruction(spirv::Opcode::OpMemberName, operands); |
233 | |
234 | binary.append(in_start: typeDecl.begin(), in_end: typeDecl.end()); |
235 | ASSERT_FALSE(deserialize()); |
236 | expectDiagnostic(errorMessage: "unexpected trailing words in OpMemberName instruction" ); |
237 | } |
238 | |
239 | //===----------------------------------------------------------------------===// |
240 | // Functions |
241 | //===----------------------------------------------------------------------===// |
242 | |
243 | TEST_F(DeserializationTest, FunctionMissingEndFailure) { |
244 | addHeader(); |
245 | auto voidType = addVoidType(); |
246 | auto fnType = addFunctionType(retType: voidType, paramTypes: {}); |
247 | addFunction(retType: voidType, fnType); |
248 | // Missing OpFunctionEnd. |
249 | |
250 | ASSERT_FALSE(deserialize()); |
251 | expectDiagnostic(errorMessage: "expected OpFunctionEnd instruction" ); |
252 | } |
253 | |
254 | TEST_F(DeserializationTest, FunctionMissingParameterFailure) { |
255 | addHeader(); |
256 | auto voidType = addVoidType(); |
257 | auto i32Type = addIntType(bitwidth: 32); |
258 | auto fnType = addFunctionType(retType: voidType, paramTypes: {i32Type}); |
259 | addFunction(retType: voidType, fnType); |
260 | // Missing OpFunctionParameter. |
261 | |
262 | ASSERT_FALSE(deserialize()); |
263 | expectDiagnostic(errorMessage: "expected OpFunctionParameter instruction" ); |
264 | } |
265 | |
266 | TEST_F(DeserializationTest, FunctionMissingLabelForFirstBlockFailure) { |
267 | addHeader(); |
268 | auto voidType = addVoidType(); |
269 | auto fnType = addFunctionType(retType: voidType, paramTypes: {}); |
270 | addFunction(retType: voidType, fnType); |
271 | // Missing OpLabel. |
272 | addReturn(); |
273 | addFunctionEnd(); |
274 | |
275 | ASSERT_FALSE(deserialize()); |
276 | expectDiagnostic(errorMessage: "a basic block must start with OpLabel" ); |
277 | } |
278 | |
279 | TEST_F(DeserializationTest, FunctionMalformedLabelFailure) { |
280 | addHeader(); |
281 | auto voidType = addVoidType(); |
282 | auto fnType = addFunctionType(retType: voidType, paramTypes: {}); |
283 | addFunction(retType: voidType, fnType); |
284 | addInstruction(spirv::Opcode::OpLabel, {}); // Malformed OpLabel |
285 | addReturn(); |
286 | addFunctionEnd(); |
287 | |
288 | ASSERT_FALSE(deserialize()); |
289 | expectDiagnostic(errorMessage: "OpLabel should only have result <id>" ); |
290 | } |
291 | |