| 1 | //===- SerializationTest.cpp - SPIR-V Serialization Tests -----------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file contains corner case tests for the SPIR-V serializer that are not |
| 10 | // covered by normal serialization and deserialization roundtripping. |
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "mlir/Target/SPIRV/Serialization.h" |
| 15 | #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" |
| 16 | #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" |
| 17 | #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" |
| 18 | #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
| 19 | #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" |
| 20 | #include "mlir/IR/Builders.h" |
| 21 | #include "mlir/IR/Location.h" |
| 22 | #include "mlir/IR/MLIRContext.h" |
| 23 | #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" |
| 24 | #include "llvm/ADT/DenseSet.h" |
| 25 | #include "llvm/ADT/STLExtras.h" |
| 26 | #include "llvm/ADT/Sequence.h" |
| 27 | #include "llvm/ADT/SmallVector.h" |
| 28 | #include "llvm/ADT/StringRef.h" |
| 29 | #include "gmock/gmock.h" |
| 30 | |
| 31 | using namespace mlir; |
| 32 | |
| 33 | //===----------------------------------------------------------------------===// |
| 34 | // Test Fixture |
| 35 | //===----------------------------------------------------------------------===// |
| 36 | |
| 37 | class SerializationTest : public ::testing::Test { |
| 38 | protected: |
| 39 | SerializationTest() { |
| 40 | context.getOrLoadDialect<mlir::spirv::SPIRVDialect>(); |
| 41 | initModuleOp(); |
| 42 | } |
| 43 | |
| 44 | /// Initializes an empty SPIR-V module op. |
| 45 | void initModuleOp() { |
| 46 | OpBuilder builder(&context); |
| 47 | OperationState state(UnknownLoc::get(&context), |
| 48 | spirv::ModuleOp::getOperationName()); |
| 49 | state.addAttribute("addressing_model" , |
| 50 | builder.getAttr<spirv::AddressingModelAttr>( |
| 51 | spirv::AddressingModel::Logical)); |
| 52 | state.addAttribute("memory_model" , builder.getAttr<spirv::MemoryModelAttr>( |
| 53 | spirv::MemoryModel::GLSL450)); |
| 54 | state.addAttribute("vce_triple" , |
| 55 | spirv::VerCapExtAttr::get( |
| 56 | spirv::Version::V_1_0, ArrayRef<spirv::Capability>(), |
| 57 | ArrayRef<spirv::Extension>(), &context)); |
| 58 | spirv::ModuleOp::build(builder, state); |
| 59 | module = cast<spirv::ModuleOp>(Operation::create(state)); |
| 60 | } |
| 61 | |
| 62 | /// Gets the `struct { float }` type. |
| 63 | spirv::StructType getFloatStructType() { |
| 64 | OpBuilder builder(module->getRegion()); |
| 65 | llvm::SmallVector<Type, 1> elementTypes{builder.getF32Type()}; |
| 66 | llvm::SmallVector<spirv::StructType::OffsetInfo, 1> offsetInfo{0}; |
| 67 | return spirv::StructType::get(memberTypes: elementTypes, offsetInfo); |
| 68 | } |
| 69 | |
| 70 | /// Inserts a global variable of the given `type` and `name`. |
| 71 | spirv::GlobalVariableOp addGlobalVar(Type type, llvm::StringRef name) { |
| 72 | OpBuilder builder(module->getRegion()); |
| 73 | auto ptrType = spirv::PointerType::get(type, spirv::StorageClass::Uniform); |
| 74 | return builder.create<spirv::GlobalVariableOp>( |
| 75 | UnknownLoc::get(&context), TypeAttr::get(ptrType), |
| 76 | builder.getStringAttr(name), nullptr); |
| 77 | } |
| 78 | |
| 79 | // Inserts an Integer or a Vector of Integers constant of value 'val'. |
| 80 | spirv::ConstantOp addConstInt(Type type, const APInt &val) { |
| 81 | OpBuilder builder(module->getRegion()); |
| 82 | auto loc = UnknownLoc::get(&context); |
| 83 | |
| 84 | if (auto intType = dyn_cast<IntegerType>(type)) { |
| 85 | return builder.create<spirv::ConstantOp>( |
| 86 | loc, type, builder.getIntegerAttr(type, val)); |
| 87 | } |
| 88 | if (auto vectorType = dyn_cast<VectorType>(type)) { |
| 89 | Type elemType = vectorType.getElementType(); |
| 90 | if (auto intType = dyn_cast<IntegerType>(elemType)) { |
| 91 | return builder.create<spirv::ConstantOp>( |
| 92 | loc, type, |
| 93 | DenseElementsAttr::get(vectorType, |
| 94 | IntegerAttr::get(elemType, val).getValue())); |
| 95 | } |
| 96 | } |
| 97 | llvm_unreachable("unimplemented types for AddConstInt()" ); |
| 98 | } |
| 99 | |
| 100 | /// Handles a SPIR-V instruction with the given `opcode` and `operand`. |
| 101 | /// Returns true to interrupt. |
| 102 | using HandleFn = llvm::function_ref<bool(spirv::Opcode opcode, |
| 103 | ArrayRef<uint32_t> operands)>; |
| 104 | |
| 105 | /// Returns true if we can find a matching instruction in the SPIR-V blob. |
| 106 | bool scanInstruction(HandleFn handleFn) { |
| 107 | auto binarySize = binary.size(); |
| 108 | auto *begin = binary.begin(); |
| 109 | auto currOffset = spirv::kHeaderWordCount; |
| 110 | |
| 111 | while (currOffset < binarySize) { |
| 112 | auto wordCount = binary[currOffset] >> 16; |
| 113 | if (!wordCount || (currOffset + wordCount > binarySize)) |
| 114 | return false; |
| 115 | |
| 116 | spirv::Opcode opcode = |
| 117 | static_cast<spirv::Opcode>(binary[currOffset] & 0xffff); |
| 118 | llvm::ArrayRef<uint32_t> operands(begin + currOffset + 1, |
| 119 | begin + currOffset + wordCount); |
| 120 | if (handleFn(opcode, operands)) |
| 121 | return true; |
| 122 | |
| 123 | currOffset += wordCount; |
| 124 | } |
| 125 | return false; |
| 126 | } |
| 127 | |
| 128 | protected: |
| 129 | MLIRContext context; |
| 130 | OwningOpRef<spirv::ModuleOp> module; |
| 131 | SmallVector<uint32_t, 0> binary; |
| 132 | }; |
| 133 | |
| 134 | //===----------------------------------------------------------------------===// |
| 135 | // Block decoration |
| 136 | //===----------------------------------------------------------------------===// |
| 137 | |
| 138 | TEST_F(SerializationTest, ContainsBlockDecoration) { |
| 139 | auto structType = getFloatStructType(); |
| 140 | addGlobalVar(structType, "var0" ); |
| 141 | |
| 142 | ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary))); |
| 143 | |
| 144 | auto hasBlockDecoration = [](spirv::Opcode opcode, |
| 145 | ArrayRef<uint32_t> operands) { |
| 146 | return opcode == spirv::Opcode::OpDecorate && operands.size() == 2 && |
| 147 | operands[1] == static_cast<uint32_t>(spirv::Decoration::Block); |
| 148 | }; |
| 149 | EXPECT_TRUE(scanInstruction(hasBlockDecoration)); |
| 150 | } |
| 151 | |
| 152 | TEST_F(SerializationTest, ContainsNoDuplicatedBlockDecoration) { |
| 153 | auto structType = getFloatStructType(); |
| 154 | // Two global variables using the same type should not decorate the type with |
| 155 | // duplicated `Block` decorations. |
| 156 | addGlobalVar(structType, "var0" ); |
| 157 | addGlobalVar(structType, "var1" ); |
| 158 | |
| 159 | ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary))); |
| 160 | |
| 161 | unsigned count = 0; |
| 162 | auto countBlockDecoration = [&count](spirv::Opcode opcode, |
| 163 | ArrayRef<uint32_t> operands) { |
| 164 | if (opcode == spirv::Opcode::OpDecorate && operands.size() == 2 && |
| 165 | operands[1] == static_cast<uint32_t>(spirv::Decoration::Block)) |
| 166 | ++count; |
| 167 | return false; |
| 168 | }; |
| 169 | ASSERT_FALSE(scanInstruction(countBlockDecoration)); |
| 170 | EXPECT_EQ(count, 1u); |
| 171 | } |
| 172 | |
| 173 | TEST_F(SerializationTest, SignlessVsSignedIntegerConstantBitExtension) { |
| 174 | |
| 175 | auto signlessInt16Type = |
| 176 | IntegerType::get(&context, 16, IntegerType::Signless); |
| 177 | auto signedInt16Type = IntegerType::get(&context, 16, IntegerType::Signed); |
| 178 | // Check the bit extension of same value under different signedness semantics. |
| 179 | APInt signlessIntConstVal(signlessInt16Type.getWidth(), 0xffff, |
| 180 | signlessInt16Type.getSignedness()); |
| 181 | APInt signedIntConstVal(signedInt16Type.getWidth(), -1, |
| 182 | signedInt16Type.getSignedness()); |
| 183 | |
| 184 | addConstInt(signlessInt16Type, signlessIntConstVal); |
| 185 | addConstInt(signedInt16Type, signedIntConstVal); |
| 186 | ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary))); |
| 187 | |
| 188 | auto hasSignlessVal = [&](spirv::Opcode opcode, ArrayRef<uint32_t> operands) { |
| 189 | return opcode == spirv::Opcode::OpConstant && operands.size() == 3 && |
| 190 | operands[2] == 65535; |
| 191 | }; |
| 192 | EXPECT_TRUE(scanInstruction(hasSignlessVal)); |
| 193 | |
| 194 | auto hasSignedVal = [&](spirv::Opcode opcode, ArrayRef<uint32_t> operands) { |
| 195 | return opcode == spirv::Opcode::OpConstant && operands.size() == 3 && |
| 196 | operands[2] == 4294967295; |
| 197 | }; |
| 198 | EXPECT_TRUE(scanInstruction(hasSignedVal)); |
| 199 | } |
| 200 | |
| 201 | TEST_F(SerializationTest, ContainsSymbolName) { |
| 202 | auto structType = getFloatStructType(); |
| 203 | addGlobalVar(structType, "var0" ); |
| 204 | |
| 205 | spirv::SerializationOptions options; |
| 206 | options.emitSymbolName = true; |
| 207 | ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary, options))); |
| 208 | |
| 209 | auto hasVarName = [](spirv::Opcode opcode, ArrayRef<uint32_t> operands) { |
| 210 | unsigned index = 1; // Skip the result <id> |
| 211 | return opcode == spirv::Opcode::OpName && |
| 212 | spirv::decodeStringLiteral(operands, index) == "var0" ; |
| 213 | }; |
| 214 | EXPECT_TRUE(scanInstruction(hasVarName)); |
| 215 | } |
| 216 | |
| 217 | TEST_F(SerializationTest, DoesNotContainSymbolName) { |
| 218 | auto structType = getFloatStructType(); |
| 219 | addGlobalVar(structType, "var0" ); |
| 220 | |
| 221 | spirv::SerializationOptions options; |
| 222 | options.emitSymbolName = false; |
| 223 | ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary, options))); |
| 224 | |
| 225 | auto hasVarName = [](spirv::Opcode opcode, ArrayRef<uint32_t> operands) { |
| 226 | unsigned index = 1; // Skip the result <id> |
| 227 | return opcode == spirv::Opcode::OpName && |
| 228 | spirv::decodeStringLiteral(operands, index) == "var0" ; |
| 229 | }; |
| 230 | EXPECT_FALSE(scanInstruction(hasVarName)); |
| 231 | } |
| 232 | |