1//===- DeserializationTest.cpp - SPIR-V Deserialization Tests -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// The purpose of this file is to provide negative deserialization tests.
10// For positive deserialization tests, please use serialization and
11// deserialization for roundtripping.
12//
13//===----------------------------------------------------------------------===//
14
15#include "mlir/Target/SPIRV/Deserialization.h"
16#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18#include "mlir/IR/Diagnostics.h"
19#include "mlir/IR/MLIRContext.h"
20#include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
21#include "gmock/gmock.h"
22
23#include <memory>
24
25using namespace mlir;
26
27using ::testing::StrEq;
28
29//===----------------------------------------------------------------------===//
30// Test Fixture
31//===----------------------------------------------------------------------===//
32
33/// A deserialization test fixture providing minimal SPIR-V building and
34/// diagnostic checking utilities.
35class DeserializationTest : public ::testing::Test {
36protected:
37 DeserializationTest() {
38 context.getOrLoadDialect<mlir::spirv::SPIRVDialect>();
39 // Register a diagnostic handler to capture the diagnostic so that we can
40 // check it later.
41 context.getDiagEngine().registerHandler(handler: [&](Diagnostic &diag) {
42 diagnostic = std::make_unique<Diagnostic>(args: std::move(diag));
43 });
44 }
45
46 /// Performs deserialization and returns the constructed spirv.module op.
47 OwningOpRef<spirv::ModuleOp> deserialize() {
48 return spirv::deserialize(binary, context: &context);
49 }
50
51 /// Checks there is a diagnostic generated with the given `errorMessage`.
52 void expectDiagnostic(StringRef errorMessage) {
53 ASSERT_NE(nullptr, diagnostic.get());
54
55 // TODO: check error location too.
56 EXPECT_THAT(diagnostic->str(), StrEq(std::string(errorMessage)));
57 }
58
59 //===--------------------------------------------------------------------===//
60 // SPIR-V builder methods
61 //===--------------------------------------------------------------------===//
62
63 /// Adds the SPIR-V module header to `binary`.
64 void addHeader() {
65 spirv::appendModuleHeader(binary, spirv::Version::V_1_0, /*idBound=*/0);
66 }
67
68 /// Adds the SPIR-V instruction into `binary`.
69 void addInstruction(spirv::Opcode op, ArrayRef<uint32_t> operands) {
70 uint32_t wordCount = 1 + operands.size();
71 binary.push_back(spirv::Elt: getPrefixedOpcode(wordCount, op));
72 binary.append(in_start: operands.begin(), in_end: operands.end());
73 }
74
75 uint32_t addVoidType() {
76 auto id = nextID++;
77 addInstruction(spirv::Opcode::OpTypeVoid, {id});
78 return id;
79 }
80
81 uint32_t addIntType(uint32_t bitwidth) {
82 auto id = nextID++;
83 addInstruction(spirv::Opcode::OpTypeInt, {id, bitwidth, /*signedness=*/1});
84 return id;
85 }
86
87 uint32_t addStructType(ArrayRef<uint32_t> memberTypes) {
88 auto id = nextID++;
89 SmallVector<uint32_t, 2> words;
90 words.push_back(Elt: id);
91 words.append(in_start: memberTypes.begin(), in_end: memberTypes.end());
92 addInstruction(spirv::Opcode::OpTypeStruct, words);
93 return id;
94 }
95
96 uint32_t addFunctionType(uint32_t retType, ArrayRef<uint32_t> paramTypes) {
97 auto id = nextID++;
98 SmallVector<uint32_t, 4> operands;
99 operands.push_back(Elt: id);
100 operands.push_back(Elt: retType);
101 operands.append(in_start: paramTypes.begin(), in_end: paramTypes.end());
102 addInstruction(spirv::Opcode::OpTypeFunction, operands);
103 return id;
104 }
105
106 uint32_t addFunction(uint32_t retType, uint32_t fnType) {
107 auto id = nextID++;
108 addInstruction(spirv::Opcode::OpFunction,
109 {retType, id,
110 static_cast<uint32_t>(spirv::FunctionControl::None),
111 fnType});
112 return id;
113 }
114
115 void addFunctionEnd() { addInstruction(spirv::Opcode::OpFunctionEnd, {}); }
116
117 void addReturn() { addInstruction(spirv::Opcode::OpReturn, {}); }
118
119protected:
120 SmallVector<uint32_t, 5> binary;
121 uint32_t nextID = 1;
122 MLIRContext context;
123 std::unique_ptr<Diagnostic> diagnostic;
124};
125
126//===----------------------------------------------------------------------===//
127// Basics
128//===----------------------------------------------------------------------===//
129
130TEST_F(DeserializationTest, EmptyModuleFailure) {
131 ASSERT_FALSE(deserialize());
132 expectDiagnostic(errorMessage: "SPIR-V binary module must have a 5-word header");
133}
134
135TEST_F(DeserializationTest, WrongMagicNumberFailure) {
136 addHeader();
137 binary.front() = 0xdeadbeef; // Change to a wrong magic number
138 ASSERT_FALSE(deserialize());
139 expectDiagnostic(errorMessage: "incorrect magic number");
140}
141
142TEST_F(DeserializationTest, OnlyHeaderSuccess) {
143 addHeader();
144 EXPECT_TRUE(deserialize());
145}
146
147TEST_F(DeserializationTest, ZeroWordCountFailure) {
148 addHeader();
149 binary.push_back(Elt: 0); // OpNop with zero word count
150
151 ASSERT_FALSE(deserialize());
152 expectDiagnostic(errorMessage: "word count cannot be zero");
153}
154
155TEST_F(DeserializationTest, InsufficientWordFailure) {
156 addHeader();
157 binary.push_back((2u << 16) |
158 static_cast<uint32_t>(spirv::Opcode::OpTypeVoid));
159 // Missing word for type <id>.
160
161 ASSERT_FALSE(deserialize());
162 expectDiagnostic(errorMessage: "insufficient words for the last instruction");
163}
164
165//===----------------------------------------------------------------------===//
166// Types
167//===----------------------------------------------------------------------===//
168
169TEST_F(DeserializationTest, IntTypeMissingSignednessFailure) {
170 addHeader();
171 addInstruction(spirv::Opcode::OpTypeInt, {nextID++, 32});
172
173 ASSERT_FALSE(deserialize());
174 expectDiagnostic(errorMessage: "OpTypeInt must have bitwidth and signedness parameters");
175}
176
177//===----------------------------------------------------------------------===//
178// StructType
179//===----------------------------------------------------------------------===//
180
181TEST_F(DeserializationTest, OpMemberNameSuccess) {
182 addHeader();
183 SmallVector<uint32_t, 5> typeDecl;
184 std::swap(LHS&: typeDecl, RHS&: binary);
185
186 auto int32Type = addIntType(bitwidth: 32);
187 auto structType = addStructType(memberTypes: {int32Type, int32Type});
188 std::swap(LHS&: typeDecl, RHS&: binary);
189
190 SmallVector<uint32_t, 5> operands1 = {structType, 0};
191 (void)spirv::encodeStringLiteralInto(binary&: operands1, literal: "i1");
192 addInstruction(spirv::Opcode::OpMemberName, operands1);
193
194 SmallVector<uint32_t, 5> operands2 = {structType, 1};
195 (void)spirv::encodeStringLiteralInto(binary&: operands2, literal: "i2");
196 addInstruction(spirv::Opcode::OpMemberName, operands2);
197
198 binary.append(in_start: typeDecl.begin(), in_end: typeDecl.end());
199 EXPECT_TRUE(deserialize());
200}
201
202TEST_F(DeserializationTest, OpMemberNameMissingOperands) {
203 addHeader();
204 SmallVector<uint32_t, 5> typeDecl;
205 std::swap(LHS&: typeDecl, RHS&: binary);
206
207 auto int32Type = addIntType(bitwidth: 32);
208 auto int64Type = addIntType(bitwidth: 64);
209 auto structType = addStructType(memberTypes: {int32Type, int64Type});
210 std::swap(LHS&: typeDecl, RHS&: binary);
211
212 SmallVector<uint32_t, 5> operands1 = {structType};
213 addInstruction(spirv::Opcode::OpMemberName, operands1);
214
215 binary.append(in_start: typeDecl.begin(), in_end: typeDecl.end());
216 ASSERT_FALSE(deserialize());
217 expectDiagnostic(errorMessage: "OpMemberName must have at least 3 operands");
218}
219
220TEST_F(DeserializationTest, OpMemberNameExcessOperands) {
221 addHeader();
222 SmallVector<uint32_t, 5> typeDecl;
223 std::swap(LHS&: typeDecl, RHS&: binary);
224
225 auto int32Type = addIntType(bitwidth: 32);
226 auto structType = addStructType(memberTypes: {int32Type});
227 std::swap(LHS&: typeDecl, RHS&: binary);
228
229 SmallVector<uint32_t, 5> operands = {structType, 0};
230 (void)spirv::encodeStringLiteralInto(binary&: operands, literal: "int32");
231 operands.push_back(Elt: 42);
232 addInstruction(spirv::Opcode::OpMemberName, operands);
233
234 binary.append(in_start: typeDecl.begin(), in_end: typeDecl.end());
235 ASSERT_FALSE(deserialize());
236 expectDiagnostic(errorMessage: "unexpected trailing words in OpMemberName instruction");
237}
238
239//===----------------------------------------------------------------------===//
240// Functions
241//===----------------------------------------------------------------------===//
242
243TEST_F(DeserializationTest, FunctionMissingEndFailure) {
244 addHeader();
245 auto voidType = addVoidType();
246 auto fnType = addFunctionType(retType: voidType, paramTypes: {});
247 addFunction(retType: voidType, fnType);
248 // Missing OpFunctionEnd.
249
250 ASSERT_FALSE(deserialize());
251 expectDiagnostic(errorMessage: "expected OpFunctionEnd instruction");
252}
253
254TEST_F(DeserializationTest, FunctionMissingParameterFailure) {
255 addHeader();
256 auto voidType = addVoidType();
257 auto i32Type = addIntType(bitwidth: 32);
258 auto fnType = addFunctionType(retType: voidType, paramTypes: {i32Type});
259 addFunction(retType: voidType, fnType);
260 // Missing OpFunctionParameter.
261
262 ASSERT_FALSE(deserialize());
263 expectDiagnostic(errorMessage: "expected OpFunctionParameter instruction");
264}
265
266TEST_F(DeserializationTest, FunctionMissingLabelForFirstBlockFailure) {
267 addHeader();
268 auto voidType = addVoidType();
269 auto fnType = addFunctionType(retType: voidType, paramTypes: {});
270 addFunction(retType: voidType, fnType);
271 // Missing OpLabel.
272 addReturn();
273 addFunctionEnd();
274
275 ASSERT_FALSE(deserialize());
276 expectDiagnostic(errorMessage: "a basic block must start with OpLabel");
277}
278
279TEST_F(DeserializationTest, FunctionMalformedLabelFailure) {
280 addHeader();
281 auto voidType = addVoidType();
282 auto fnType = addFunctionType(retType: voidType, paramTypes: {});
283 addFunction(retType: voidType, fnType);
284 addInstruction(spirv::Opcode::OpLabel, {}); // Malformed OpLabel
285 addReturn();
286 addFunctionEnd();
287
288 ASSERT_FALSE(deserialize());
289 expectDiagnostic(errorMessage: "OpLabel should only have result <id>");
290}
291

source code of mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp