| 1 | //===- DeserializationTest.cpp - SPIR-V Deserialization Tests -------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // The purpose of this file is to provide negative deserialization tests. |
| 10 | // For positive deserialization tests, please use serialization and |
| 11 | // deserialization for roundtripping. |
| 12 | // |
| 13 | //===----------------------------------------------------------------------===// |
| 14 | |
| 15 | #include "mlir/Target/SPIRV/Deserialization.h" |
| 16 | #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" |
| 17 | #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
| 18 | #include "mlir/IR/Diagnostics.h" |
| 19 | #include "mlir/IR/MLIRContext.h" |
| 20 | #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" |
| 21 | #include "gmock/gmock.h" |
| 22 | |
| 23 | #include <memory> |
| 24 | |
| 25 | using namespace mlir; |
| 26 | |
| 27 | using ::testing::StrEq; |
| 28 | |
| 29 | //===----------------------------------------------------------------------===// |
| 30 | // Test Fixture |
| 31 | //===----------------------------------------------------------------------===// |
| 32 | |
| 33 | /// A deserialization test fixture providing minimal SPIR-V building and |
| 34 | /// diagnostic checking utilities. |
| 35 | class DeserializationTest : public ::testing::Test { |
| 36 | protected: |
| 37 | DeserializationTest() { |
| 38 | context.getOrLoadDialect<mlir::spirv::SPIRVDialect>(); |
| 39 | // Register a diagnostic handler to capture the diagnostic so that we can |
| 40 | // check it later. |
| 41 | context.getDiagEngine().registerHandler(handler: [&](Diagnostic &diag) { |
| 42 | diagnostic = std::make_unique<Diagnostic>(args: std::move(diag)); |
| 43 | }); |
| 44 | } |
| 45 | |
| 46 | /// Performs deserialization and returns the constructed spirv.module op. |
| 47 | OwningOpRef<spirv::ModuleOp> deserialize() { |
| 48 | return spirv::deserialize(binary, context: &context); |
| 49 | } |
| 50 | |
| 51 | /// Checks there is a diagnostic generated with the given `errorMessage`. |
| 52 | void expectDiagnostic(StringRef errorMessage) { |
| 53 | ASSERT_NE(nullptr, diagnostic.get()); |
| 54 | |
| 55 | // TODO: check error location too. |
| 56 | EXPECT_THAT(diagnostic->str(), StrEq(std::string(errorMessage))); |
| 57 | } |
| 58 | |
| 59 | //===--------------------------------------------------------------------===// |
| 60 | // SPIR-V builder methods |
| 61 | //===--------------------------------------------------------------------===// |
| 62 | |
| 63 | /// Adds the SPIR-V module header to `binary`. |
| 64 | void () { |
| 65 | spirv::appendModuleHeader(binary, spirv::Version::V_1_0, /*idBound=*/0); |
| 66 | } |
| 67 | |
| 68 | /// Adds the SPIR-V instruction into `binary`. |
| 69 | void addInstruction(spirv::Opcode op, ArrayRef<uint32_t> operands) { |
| 70 | uint32_t wordCount = 1 + operands.size(); |
| 71 | binary.push_back(spirv::Elt: getPrefixedOpcode(wordCount, op)); |
| 72 | binary.append(in_start: operands.begin(), in_end: operands.end()); |
| 73 | } |
| 74 | |
| 75 | uint32_t addVoidType() { |
| 76 | auto id = nextID++; |
| 77 | addInstruction(spirv::Opcode::OpTypeVoid, {id}); |
| 78 | return id; |
| 79 | } |
| 80 | |
| 81 | uint32_t addIntType(uint32_t bitwidth) { |
| 82 | auto id = nextID++; |
| 83 | addInstruction(spirv::Opcode::OpTypeInt, {id, bitwidth, /*signedness=*/1}); |
| 84 | return id; |
| 85 | } |
| 86 | |
| 87 | uint32_t addStructType(ArrayRef<uint32_t> memberTypes) { |
| 88 | auto id = nextID++; |
| 89 | SmallVector<uint32_t, 2> words; |
| 90 | words.push_back(Elt: id); |
| 91 | words.append(in_start: memberTypes.begin(), in_end: memberTypes.end()); |
| 92 | addInstruction(spirv::Opcode::OpTypeStruct, words); |
| 93 | return id; |
| 94 | } |
| 95 | |
| 96 | uint32_t addFunctionType(uint32_t retType, ArrayRef<uint32_t> paramTypes) { |
| 97 | auto id = nextID++; |
| 98 | SmallVector<uint32_t, 4> operands; |
| 99 | operands.push_back(Elt: id); |
| 100 | operands.push_back(Elt: retType); |
| 101 | operands.append(in_start: paramTypes.begin(), in_end: paramTypes.end()); |
| 102 | addInstruction(spirv::Opcode::OpTypeFunction, operands); |
| 103 | return id; |
| 104 | } |
| 105 | |
| 106 | uint32_t addFunction(uint32_t retType, uint32_t fnType) { |
| 107 | auto id = nextID++; |
| 108 | addInstruction(spirv::Opcode::OpFunction, |
| 109 | {retType, id, |
| 110 | static_cast<uint32_t>(spirv::FunctionControl::None), |
| 111 | fnType}); |
| 112 | return id; |
| 113 | } |
| 114 | |
| 115 | void addFunctionEnd() { addInstruction(spirv::Opcode::OpFunctionEnd, {}); } |
| 116 | |
| 117 | void addReturn() { addInstruction(spirv::Opcode::OpReturn, {}); } |
| 118 | |
| 119 | protected: |
| 120 | SmallVector<uint32_t, 5> binary; |
| 121 | uint32_t nextID = 1; |
| 122 | MLIRContext context; |
| 123 | std::unique_ptr<Diagnostic> diagnostic; |
| 124 | }; |
| 125 | |
| 126 | //===----------------------------------------------------------------------===// |
| 127 | // Basics |
| 128 | //===----------------------------------------------------------------------===// |
| 129 | |
| 130 | TEST_F(DeserializationTest, EmptyModuleFailure) { |
| 131 | ASSERT_FALSE(deserialize()); |
| 132 | expectDiagnostic(errorMessage: "SPIR-V binary module must have a 5-word header" ); |
| 133 | } |
| 134 | |
| 135 | TEST_F(DeserializationTest, WrongMagicNumberFailure) { |
| 136 | addHeader(); |
| 137 | binary.front() = 0xdeadbeef; // Change to a wrong magic number |
| 138 | ASSERT_FALSE(deserialize()); |
| 139 | expectDiagnostic(errorMessage: "incorrect magic number" ); |
| 140 | } |
| 141 | |
| 142 | TEST_F(DeserializationTest, OnlyHeaderSuccess) { |
| 143 | addHeader(); |
| 144 | EXPECT_TRUE(deserialize()); |
| 145 | } |
| 146 | |
| 147 | TEST_F(DeserializationTest, ZeroWordCountFailure) { |
| 148 | addHeader(); |
| 149 | binary.push_back(Elt: 0); // OpNop with zero word count |
| 150 | |
| 151 | ASSERT_FALSE(deserialize()); |
| 152 | expectDiagnostic(errorMessage: "word count cannot be zero" ); |
| 153 | } |
| 154 | |
| 155 | TEST_F(DeserializationTest, InsufficientWordFailure) { |
| 156 | addHeader(); |
| 157 | binary.push_back((2u << 16) | |
| 158 | static_cast<uint32_t>(spirv::Opcode::OpTypeVoid)); |
| 159 | // Missing word for type <id>. |
| 160 | |
| 161 | ASSERT_FALSE(deserialize()); |
| 162 | expectDiagnostic(errorMessage: "insufficient words for the last instruction" ); |
| 163 | } |
| 164 | |
| 165 | //===----------------------------------------------------------------------===// |
| 166 | // Types |
| 167 | //===----------------------------------------------------------------------===// |
| 168 | |
| 169 | TEST_F(DeserializationTest, IntTypeMissingSignednessFailure) { |
| 170 | addHeader(); |
| 171 | addInstruction(spirv::Opcode::OpTypeInt, {nextID++, 32}); |
| 172 | |
| 173 | ASSERT_FALSE(deserialize()); |
| 174 | expectDiagnostic(errorMessage: "OpTypeInt must have bitwidth and signedness parameters" ); |
| 175 | } |
| 176 | |
| 177 | //===----------------------------------------------------------------------===// |
| 178 | // StructType |
| 179 | //===----------------------------------------------------------------------===// |
| 180 | |
| 181 | TEST_F(DeserializationTest, OpMemberNameSuccess) { |
| 182 | addHeader(); |
| 183 | SmallVector<uint32_t, 5> typeDecl; |
| 184 | std::swap(LHS&: typeDecl, RHS&: binary); |
| 185 | |
| 186 | auto int32Type = addIntType(bitwidth: 32); |
| 187 | auto structType = addStructType(memberTypes: {int32Type, int32Type}); |
| 188 | std::swap(LHS&: typeDecl, RHS&: binary); |
| 189 | |
| 190 | SmallVector<uint32_t, 5> operands1 = {structType, 0}; |
| 191 | (void)spirv::encodeStringLiteralInto(binary&: operands1, literal: "i1" ); |
| 192 | addInstruction(spirv::Opcode::OpMemberName, operands1); |
| 193 | |
| 194 | SmallVector<uint32_t, 5> operands2 = {structType, 1}; |
| 195 | (void)spirv::encodeStringLiteralInto(binary&: operands2, literal: "i2" ); |
| 196 | addInstruction(spirv::Opcode::OpMemberName, operands2); |
| 197 | |
| 198 | binary.append(in_start: typeDecl.begin(), in_end: typeDecl.end()); |
| 199 | EXPECT_TRUE(deserialize()); |
| 200 | } |
| 201 | |
| 202 | TEST_F(DeserializationTest, OpMemberNameMissingOperands) { |
| 203 | addHeader(); |
| 204 | SmallVector<uint32_t, 5> typeDecl; |
| 205 | std::swap(LHS&: typeDecl, RHS&: binary); |
| 206 | |
| 207 | auto int32Type = addIntType(bitwidth: 32); |
| 208 | auto int64Type = addIntType(bitwidth: 64); |
| 209 | auto structType = addStructType(memberTypes: {int32Type, int64Type}); |
| 210 | std::swap(LHS&: typeDecl, RHS&: binary); |
| 211 | |
| 212 | SmallVector<uint32_t, 5> operands1 = {structType}; |
| 213 | addInstruction(spirv::Opcode::OpMemberName, operands1); |
| 214 | |
| 215 | binary.append(in_start: typeDecl.begin(), in_end: typeDecl.end()); |
| 216 | ASSERT_FALSE(deserialize()); |
| 217 | expectDiagnostic(errorMessage: "OpMemberName must have at least 3 operands" ); |
| 218 | } |
| 219 | |
| 220 | TEST_F(DeserializationTest, OpMemberNameExcessOperands) { |
| 221 | addHeader(); |
| 222 | SmallVector<uint32_t, 5> typeDecl; |
| 223 | std::swap(LHS&: typeDecl, RHS&: binary); |
| 224 | |
| 225 | auto int32Type = addIntType(bitwidth: 32); |
| 226 | auto structType = addStructType(memberTypes: {int32Type}); |
| 227 | std::swap(LHS&: typeDecl, RHS&: binary); |
| 228 | |
| 229 | SmallVector<uint32_t, 5> operands = {structType, 0}; |
| 230 | (void)spirv::encodeStringLiteralInto(binary&: operands, literal: "int32" ); |
| 231 | operands.push_back(Elt: 42); |
| 232 | addInstruction(spirv::Opcode::OpMemberName, operands); |
| 233 | |
| 234 | binary.append(in_start: typeDecl.begin(), in_end: typeDecl.end()); |
| 235 | ASSERT_FALSE(deserialize()); |
| 236 | expectDiagnostic(errorMessage: "unexpected trailing words in OpMemberName instruction" ); |
| 237 | } |
| 238 | |
| 239 | //===----------------------------------------------------------------------===// |
| 240 | // Functions |
| 241 | //===----------------------------------------------------------------------===// |
| 242 | |
| 243 | TEST_F(DeserializationTest, FunctionMissingEndFailure) { |
| 244 | addHeader(); |
| 245 | auto voidType = addVoidType(); |
| 246 | auto fnType = addFunctionType(retType: voidType, paramTypes: {}); |
| 247 | addFunction(retType: voidType, fnType); |
| 248 | // Missing OpFunctionEnd. |
| 249 | |
| 250 | ASSERT_FALSE(deserialize()); |
| 251 | expectDiagnostic(errorMessage: "expected OpFunctionEnd instruction" ); |
| 252 | } |
| 253 | |
| 254 | TEST_F(DeserializationTest, FunctionMissingParameterFailure) { |
| 255 | addHeader(); |
| 256 | auto voidType = addVoidType(); |
| 257 | auto i32Type = addIntType(bitwidth: 32); |
| 258 | auto fnType = addFunctionType(retType: voidType, paramTypes: {i32Type}); |
| 259 | addFunction(retType: voidType, fnType); |
| 260 | // Missing OpFunctionParameter. |
| 261 | |
| 262 | ASSERT_FALSE(deserialize()); |
| 263 | expectDiagnostic(errorMessage: "expected OpFunctionParameter instruction" ); |
| 264 | } |
| 265 | |
| 266 | TEST_F(DeserializationTest, FunctionMissingLabelForFirstBlockFailure) { |
| 267 | addHeader(); |
| 268 | auto voidType = addVoidType(); |
| 269 | auto fnType = addFunctionType(retType: voidType, paramTypes: {}); |
| 270 | addFunction(retType: voidType, fnType); |
| 271 | // Missing OpLabel. |
| 272 | addReturn(); |
| 273 | addFunctionEnd(); |
| 274 | |
| 275 | ASSERT_FALSE(deserialize()); |
| 276 | expectDiagnostic(errorMessage: "a basic block must start with OpLabel" ); |
| 277 | } |
| 278 | |
| 279 | TEST_F(DeserializationTest, FunctionMalformedLabelFailure) { |
| 280 | addHeader(); |
| 281 | auto voidType = addVoidType(); |
| 282 | auto fnType = addFunctionType(retType: voidType, paramTypes: {}); |
| 283 | addFunction(retType: voidType, fnType); |
| 284 | addInstruction(spirv::Opcode::OpLabel, {}); // Malformed OpLabel |
| 285 | addReturn(); |
| 286 | addFunctionEnd(); |
| 287 | |
| 288 | ASSERT_FALSE(deserialize()); |
| 289 | expectDiagnostic(errorMessage: "OpLabel should only have result <id>" ); |
| 290 | } |
| 291 | |