1 | //===- SerializationTest.cpp - SPIR-V Serialization Tests -----------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file contains corner case tests for the SPIR-V serializer that are not |
10 | // covered by normal serialization and deserialization roundtripping. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Target/SPIRV/Serialization.h" |
15 | #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" |
16 | #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" |
17 | #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" |
18 | #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
19 | #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" |
20 | #include "mlir/IR/Builders.h" |
21 | #include "mlir/IR/Location.h" |
22 | #include "mlir/IR/MLIRContext.h" |
23 | #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" |
24 | #include "llvm/ADT/DenseSet.h" |
25 | #include "llvm/ADT/STLExtras.h" |
26 | #include "llvm/ADT/Sequence.h" |
27 | #include "llvm/ADT/SmallVector.h" |
28 | #include "llvm/ADT/StringRef.h" |
29 | #include "gmock/gmock.h" |
30 | |
31 | using namespace mlir; |
32 | |
33 | //===----------------------------------------------------------------------===// |
34 | // Test Fixture |
35 | //===----------------------------------------------------------------------===// |
36 | |
37 | class SerializationTest : public ::testing::Test { |
38 | protected: |
39 | SerializationTest() { |
40 | context.getOrLoadDialect<mlir::spirv::SPIRVDialect>(); |
41 | initModuleOp(); |
42 | } |
43 | |
44 | /// Initializes an empty SPIR-V module op. |
45 | void initModuleOp() { |
46 | OpBuilder builder(&context); |
47 | OperationState state(UnknownLoc::get(&context), |
48 | spirv::ModuleOp::getOperationName()); |
49 | state.addAttribute("addressing_model" , |
50 | builder.getAttr<spirv::AddressingModelAttr>( |
51 | spirv::AddressingModel::Logical)); |
52 | state.addAttribute("memory_model" , builder.getAttr<spirv::MemoryModelAttr>( |
53 | spirv::MemoryModel::GLSL450)); |
54 | state.addAttribute("vce_triple" , |
55 | spirv::VerCapExtAttr::get( |
56 | spirv::Version::V_1_0, ArrayRef<spirv::Capability>(), |
57 | ArrayRef<spirv::Extension>(), &context)); |
58 | spirv::ModuleOp::build(builder, state); |
59 | module = cast<spirv::ModuleOp>(Operation::create(state)); |
60 | } |
61 | |
62 | /// Gets the `struct { float }` type. |
63 | spirv::StructType getFloatStructType() { |
64 | OpBuilder builder(module->getRegion()); |
65 | llvm::SmallVector<Type, 1> elementTypes{builder.getF32Type()}; |
66 | llvm::SmallVector<spirv::StructType::OffsetInfo, 1> offsetInfo{0}; |
67 | return spirv::StructType::get(memberTypes: elementTypes, offsetInfo); |
68 | } |
69 | |
70 | /// Inserts a global variable of the given `type` and `name`. |
71 | spirv::GlobalVariableOp addGlobalVar(Type type, llvm::StringRef name) { |
72 | OpBuilder builder(module->getRegion()); |
73 | auto ptrType = spirv::PointerType::get(type, spirv::StorageClass::Uniform); |
74 | return builder.create<spirv::GlobalVariableOp>( |
75 | UnknownLoc::get(&context), TypeAttr::get(ptrType), |
76 | builder.getStringAttr(name), nullptr); |
77 | } |
78 | |
79 | // Inserts an Integer or a Vector of Integers constant of value 'val'. |
80 | spirv::ConstantOp addConstInt(Type type, const APInt &val) { |
81 | OpBuilder builder(module->getRegion()); |
82 | auto loc = UnknownLoc::get(&context); |
83 | |
84 | if (auto intType = dyn_cast<IntegerType>(type)) { |
85 | return builder.create<spirv::ConstantOp>( |
86 | loc, type, builder.getIntegerAttr(type, val)); |
87 | } |
88 | if (auto vectorType = dyn_cast<VectorType>(type)) { |
89 | Type elemType = vectorType.getElementType(); |
90 | if (auto intType = dyn_cast<IntegerType>(elemType)) { |
91 | return builder.create<spirv::ConstantOp>( |
92 | loc, type, |
93 | DenseElementsAttr::get(vectorType, |
94 | IntegerAttr::get(elemType, val).getValue())); |
95 | } |
96 | } |
97 | llvm_unreachable("unimplemented types for AddConstInt()" ); |
98 | } |
99 | |
100 | /// Handles a SPIR-V instruction with the given `opcode` and `operand`. |
101 | /// Returns true to interrupt. |
102 | using HandleFn = llvm::function_ref<bool(spirv::Opcode opcode, |
103 | ArrayRef<uint32_t> operands)>; |
104 | |
105 | /// Returns true if we can find a matching instruction in the SPIR-V blob. |
106 | bool scanInstruction(HandleFn handleFn) { |
107 | auto binarySize = binary.size(); |
108 | auto *begin = binary.begin(); |
109 | auto currOffset = spirv::kHeaderWordCount; |
110 | |
111 | while (currOffset < binarySize) { |
112 | auto wordCount = binary[currOffset] >> 16; |
113 | if (!wordCount || (currOffset + wordCount > binarySize)) |
114 | return false; |
115 | |
116 | spirv::Opcode opcode = |
117 | static_cast<spirv::Opcode>(binary[currOffset] & 0xffff); |
118 | llvm::ArrayRef<uint32_t> operands(begin + currOffset + 1, |
119 | begin + currOffset + wordCount); |
120 | if (handleFn(opcode, operands)) |
121 | return true; |
122 | |
123 | currOffset += wordCount; |
124 | } |
125 | return false; |
126 | } |
127 | |
128 | protected: |
129 | MLIRContext context; |
130 | OwningOpRef<spirv::ModuleOp> module; |
131 | SmallVector<uint32_t, 0> binary; |
132 | }; |
133 | |
134 | //===----------------------------------------------------------------------===// |
135 | // Block decoration |
136 | //===----------------------------------------------------------------------===// |
137 | |
138 | TEST_F(SerializationTest, ContainsBlockDecoration) { |
139 | auto structType = getFloatStructType(); |
140 | addGlobalVar(structType, "var0" ); |
141 | |
142 | ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary))); |
143 | |
144 | auto hasBlockDecoration = [](spirv::Opcode opcode, |
145 | ArrayRef<uint32_t> operands) { |
146 | return opcode == spirv::Opcode::OpDecorate && operands.size() == 2 && |
147 | operands[1] == static_cast<uint32_t>(spirv::Decoration::Block); |
148 | }; |
149 | EXPECT_TRUE(scanInstruction(hasBlockDecoration)); |
150 | } |
151 | |
152 | TEST_F(SerializationTest, ContainsNoDuplicatedBlockDecoration) { |
153 | auto structType = getFloatStructType(); |
154 | // Two global variables using the same type should not decorate the type with |
155 | // duplicated `Block` decorations. |
156 | addGlobalVar(structType, "var0" ); |
157 | addGlobalVar(structType, "var1" ); |
158 | |
159 | ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary))); |
160 | |
161 | unsigned count = 0; |
162 | auto countBlockDecoration = [&count](spirv::Opcode opcode, |
163 | ArrayRef<uint32_t> operands) { |
164 | if (opcode == spirv::Opcode::OpDecorate && operands.size() == 2 && |
165 | operands[1] == static_cast<uint32_t>(spirv::Decoration::Block)) |
166 | ++count; |
167 | return false; |
168 | }; |
169 | ASSERT_FALSE(scanInstruction(countBlockDecoration)); |
170 | EXPECT_EQ(count, 1u); |
171 | } |
172 | |
173 | TEST_F(SerializationTest, SignlessVsSignedIntegerConstantBitExtension) { |
174 | |
175 | auto signlessInt16Type = |
176 | IntegerType::get(&context, 16, IntegerType::Signless); |
177 | auto signedInt16Type = IntegerType::get(&context, 16, IntegerType::Signed); |
178 | // Check the bit extension of same value under different signedness semantics. |
179 | APInt signlessIntConstVal(signlessInt16Type.getWidth(), -1, |
180 | signlessInt16Type.getSignedness()); |
181 | APInt signedIntConstVal(signedInt16Type.getWidth(), -1, |
182 | signedInt16Type.getSignedness()); |
183 | |
184 | addConstInt(signlessInt16Type, signlessIntConstVal); |
185 | addConstInt(signedInt16Type, signedIntConstVal); |
186 | ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary))); |
187 | |
188 | auto hasSignlessVal = [&](spirv::Opcode opcode, ArrayRef<uint32_t> operands) { |
189 | return opcode == spirv::Opcode::OpConstant && operands.size() == 3 && |
190 | operands[2] == 65535; |
191 | }; |
192 | EXPECT_TRUE(scanInstruction(hasSignlessVal)); |
193 | |
194 | auto hasSignedVal = [&](spirv::Opcode opcode, ArrayRef<uint32_t> operands) { |
195 | return opcode == spirv::Opcode::OpConstant && operands.size() == 3 && |
196 | operands[2] == 4294967295; |
197 | }; |
198 | EXPECT_TRUE(scanInstruction(hasSignedVal)); |
199 | } |
200 | |
201 | TEST_F(SerializationTest, ContainsSymbolName) { |
202 | auto structType = getFloatStructType(); |
203 | addGlobalVar(structType, "var0" ); |
204 | |
205 | spirv::SerializationOptions options; |
206 | options.emitSymbolName = true; |
207 | ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary, options))); |
208 | |
209 | auto hasVarName = [](spirv::Opcode opcode, ArrayRef<uint32_t> operands) { |
210 | unsigned index = 1; // Skip the result <id> |
211 | return opcode == spirv::Opcode::OpName && |
212 | spirv::decodeStringLiteral(operands, index) == "var0" ; |
213 | }; |
214 | EXPECT_TRUE(scanInstruction(hasVarName)); |
215 | } |
216 | |
217 | TEST_F(SerializationTest, DoesNotContainSymbolName) { |
218 | auto structType = getFloatStructType(); |
219 | addGlobalVar(structType, "var0" ); |
220 | |
221 | spirv::SerializationOptions options; |
222 | options.emitSymbolName = false; |
223 | ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary, options))); |
224 | |
225 | auto hasVarName = [](spirv::Opcode opcode, ArrayRef<uint32_t> operands) { |
226 | unsigned index = 1; // Skip the result <id> |
227 | return opcode == spirv::Opcode::OpName && |
228 | spirv::decodeStringLiteral(operands, index) == "var0" ; |
229 | }; |
230 | EXPECT_FALSE(scanInstruction(hasVarName)); |
231 | } |
232 | |