1//===- DeserializeOps.cpp - MLIR SPIR-V Deserialization (Ops) -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the Deserializer methods for SPIR-V binary instructions.
10//
11//===----------------------------------------------------------------------===//
12
13#include "Deserializer.h"
14
15#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
16#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17#include "mlir/IR/Builders.h"
18#include "mlir/IR/Location.h"
19#include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
20#include "llvm/ADT/STLExtras.h"
21#include "llvm/ADT/SmallVector.h"
22#include "llvm/Support/Debug.h"
23#include <optional>
24
25using namespace mlir;
26
27#define DEBUG_TYPE "spirv-deserialization"
28
29//===----------------------------------------------------------------------===//
30// Utility Functions
31//===----------------------------------------------------------------------===//
32
33/// Extracts the opcode from the given first word of a SPIR-V instruction.
34static inline spirv::Opcode extractOpcode(uint32_t word) {
35 return static_cast<spirv::Opcode>(word & 0xffff);
36}
37
38//===----------------------------------------------------------------------===//
39// Instruction
40//===----------------------------------------------------------------------===//
41
42Value spirv::Deserializer::getValue(uint32_t id) {
43 if (auto constInfo = getConstant(id)) {
44 // Materialize a `spirv.Constant` op at every use site.
45 return opBuilder.create<spirv::ConstantOp>(unknownLoc, constInfo->second,
46 constInfo->first);
47 }
48 if (auto varOp = getGlobalVariable(id)) {
49 auto addressOfOp = opBuilder.create<spirv::AddressOfOp>(
50 unknownLoc, varOp.getType(), SymbolRefAttr::get(varOp.getOperation()));
51 return addressOfOp.getPointer();
52 }
53 if (auto constOp = getSpecConstant(id)) {
54 auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
55 unknownLoc, constOp.getDefaultValue().getType(),
56 SymbolRefAttr::get(constOp.getOperation()));
57 return referenceOfOp.getReference();
58 }
59 if (auto constCompositeOp = getSpecConstantComposite(id)) {
60 auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
61 unknownLoc, constCompositeOp.getType(),
62 SymbolRefAttr::get(constCompositeOp.getOperation()));
63 return referenceOfOp.getReference();
64 }
65 if (auto specConstOperationInfo = getSpecConstantOperation(id)) {
66 return materializeSpecConstantOperation(
67 id, specConstOperationInfo->enclodesOpcode,
68 specConstOperationInfo->resultTypeID,
69 specConstOperationInfo->enclosedOpOperands);
70 }
71 if (auto undef = getUndefType(id)) {
72 return opBuilder.create<spirv::UndefOp>(unknownLoc, undef);
73 }
74 return valueMap.lookup(Val: id);
75}
76
77LogicalResult spirv::Deserializer::sliceInstruction(
78 spirv::Opcode &opcode, ArrayRef<uint32_t> &operands,
79 std::optional<spirv::Opcode> expectedOpcode) {
80 auto binarySize = binary.size();
81 if (curOffset >= binarySize) {
82 return emitError(unknownLoc, "expected ")
83 << (expectedOpcode ? spirv::stringifyOpcode(*expectedOpcode)
84 : "more")
85 << " instruction";
86 }
87
88 // For each instruction, get its word count from the first word to slice it
89 // from the stream properly, and then dispatch to the instruction handler.
90
91 uint32_t wordCount = binary[curOffset] >> 16;
92
93 if (wordCount == 0)
94 return emitError(loc: unknownLoc, message: "word count cannot be zero");
95
96 uint32_t nextOffset = curOffset + wordCount;
97 if (nextOffset > binarySize)
98 return emitError(loc: unknownLoc, message: "insufficient words for the last instruction");
99
100 opcode = extractOpcode(binary[curOffset]);
101 operands = binary.slice(N: curOffset + 1, M: wordCount - 1);
102 curOffset = nextOffset;
103 return success();
104}
105
106LogicalResult spirv::Deserializer::processInstruction(
107 spirv::Opcode opcode, ArrayRef<uint32_t> operands, bool deferInstructions) {
108 LLVM_DEBUG(logger.startLine() << "[inst] processing instruction "
109 << spirv::stringifyOpcode(opcode) << "\n");
110
111 // First dispatch all the instructions whose opcode does not correspond to
112 // those that have a direct mirror in the SPIR-V dialect
113 switch (opcode) {
114 case spirv::Opcode::OpCapability:
115 return processCapability(operands);
116 case spirv::Opcode::OpExtension:
117 return processExtension(words: operands);
118 case spirv::Opcode::OpExtInst:
119 return processExtInst(operands);
120 case spirv::Opcode::OpExtInstImport:
121 return processExtInstImport(words: operands);
122 case spirv::Opcode::OpMemberName:
123 return processMemberName(words: operands);
124 case spirv::Opcode::OpMemoryModel:
125 return processMemoryModel(operands);
126 case spirv::Opcode::OpEntryPoint:
127 case spirv::Opcode::OpExecutionMode:
128 if (deferInstructions) {
129 deferredInstructions.emplace_back(opcode, operands);
130 return success();
131 }
132 break;
133 case spirv::Opcode::OpVariable:
134 if (isa<spirv::ModuleOp>(opBuilder.getBlock()->getParentOp())) {
135 return processGlobalVariable(operands);
136 }
137 break;
138 case spirv::Opcode::OpLine:
139 return processDebugLine(operands);
140 case spirv::Opcode::OpNoLine:
141 clearDebugLine();
142 return success();
143 case spirv::Opcode::OpName:
144 return processName(operands);
145 case spirv::Opcode::OpString:
146 return processDebugString(operands);
147 case spirv::Opcode::OpModuleProcessed:
148 case spirv::Opcode::OpSource:
149 case spirv::Opcode::OpSourceContinued:
150 case spirv::Opcode::OpSourceExtension:
151 // TODO: This is debug information embedded in the binary which should be
152 // translated into the spirv.module.
153 return success();
154 case spirv::Opcode::OpTypeVoid:
155 case spirv::Opcode::OpTypeBool:
156 case spirv::Opcode::OpTypeInt:
157 case spirv::Opcode::OpTypeFloat:
158 case spirv::Opcode::OpTypeVector:
159 case spirv::Opcode::OpTypeMatrix:
160 case spirv::Opcode::OpTypeArray:
161 case spirv::Opcode::OpTypeFunction:
162 case spirv::Opcode::OpTypeImage:
163 case spirv::Opcode::OpTypeSampledImage:
164 case spirv::Opcode::OpTypeRuntimeArray:
165 case spirv::Opcode::OpTypeStruct:
166 case spirv::Opcode::OpTypePointer:
167 case spirv::Opcode::OpTypeCooperativeMatrixKHR:
168 return processType(opcode, operands);
169 case spirv::Opcode::OpTypeForwardPointer:
170 return processTypeForwardPointer(operands);
171 case spirv::Opcode::OpTypeJointMatrixINTEL:
172 return processType(opcode, operands);
173 case spirv::Opcode::OpConstant:
174 return processConstant(operands, /*isSpec=*/false);
175 case spirv::Opcode::OpSpecConstant:
176 return processConstant(operands, /*isSpec=*/true);
177 case spirv::Opcode::OpConstantComposite:
178 return processConstantComposite(operands);
179 case spirv::Opcode::OpSpecConstantComposite:
180 return processSpecConstantComposite(operands);
181 case spirv::Opcode::OpSpecConstantOp:
182 return processSpecConstantOperation(operands);
183 case spirv::Opcode::OpConstantTrue:
184 return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/false);
185 case spirv::Opcode::OpSpecConstantTrue:
186 return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/true);
187 case spirv::Opcode::OpConstantFalse:
188 return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/false);
189 case spirv::Opcode::OpSpecConstantFalse:
190 return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/true);
191 case spirv::Opcode::OpConstantNull:
192 return processConstantNull(operands);
193 case spirv::Opcode::OpDecorate:
194 return processDecoration(words: operands);
195 case spirv::Opcode::OpMemberDecorate:
196 return processMemberDecoration(words: operands);
197 case spirv::Opcode::OpFunction:
198 return processFunction(operands);
199 case spirv::Opcode::OpLabel:
200 return processLabel(operands);
201 case spirv::Opcode::OpBranch:
202 return processBranch(operands);
203 case spirv::Opcode::OpBranchConditional:
204 return processBranchConditional(operands);
205 case spirv::Opcode::OpSelectionMerge:
206 return processSelectionMerge(operands);
207 case spirv::Opcode::OpLoopMerge:
208 return processLoopMerge(operands);
209 case spirv::Opcode::OpPhi:
210 return processPhi(operands);
211 case spirv::Opcode::OpUndef:
212 return processUndef(operands);
213 default:
214 break;
215 }
216 return dispatchToAutogenDeserialization(opcode, operands);
217}
218
219LogicalResult spirv::Deserializer::processOpWithoutGrammarAttr(
220 ArrayRef<uint32_t> words, StringRef opName, bool hasResult,
221 unsigned numOperands) {
222 SmallVector<Type, 1> resultTypes;
223 uint32_t valueID = 0;
224
225 size_t wordIndex = 0;
226 if (hasResult) {
227 if (wordIndex >= words.size())
228 return emitError(loc: unknownLoc,
229 message: "expected result type <id> while deserializing for ")
230 << opName;
231
232 // Decode the type <id>
233 auto type = getType(id: words[wordIndex]);
234 if (!type)
235 return emitError(loc: unknownLoc, message: "unknown type result <id>: ")
236 << words[wordIndex];
237 resultTypes.push_back(Elt: type);
238 ++wordIndex;
239
240 // Decode the result <id>
241 if (wordIndex >= words.size())
242 return emitError(loc: unknownLoc,
243 message: "expected result <id> while deserializing for ")
244 << opName;
245 valueID = words[wordIndex];
246 ++wordIndex;
247 }
248
249 SmallVector<Value, 4> operands;
250 SmallVector<NamedAttribute, 4> attributes;
251
252 // Decode operands
253 size_t operandIndex = 0;
254 for (; operandIndex < numOperands && wordIndex < words.size();
255 ++operandIndex, ++wordIndex) {
256 auto arg = getValue(id: words[wordIndex]);
257 if (!arg)
258 return emitError(loc: unknownLoc, message: "unknown result <id>: ") << words[wordIndex];
259 operands.push_back(Elt: arg);
260 }
261 if (operandIndex != numOperands) {
262 return emitError(
263 loc: unknownLoc,
264 message: "found less operands than expected when deserializing for ")
265 << opName << "; only " << operandIndex << " of " << numOperands
266 << " processed";
267 }
268 if (wordIndex != words.size()) {
269 return emitError(
270 loc: unknownLoc,
271 message: "found more operands than expected when deserializing for ")
272 << opName << "; only " << wordIndex << " of " << words.size()
273 << " processed";
274 }
275
276 // Attach attributes from decorations
277 if (decorations.count(Val: valueID)) {
278 auto attrs = decorations[valueID].getAttrs();
279 attributes.append(in_start: attrs.begin(), in_end: attrs.end());
280 }
281
282 // Create the op and update bookkeeping maps
283 Location loc = createFileLineColLoc(opBuilder);
284 OperationState opState(loc, opName);
285 opState.addOperands(newOperands: operands);
286 if (hasResult)
287 opState.addTypes(newTypes: resultTypes);
288 opState.addAttributes(newAttributes: attributes);
289 Operation *op = opBuilder.create(state: opState);
290 if (hasResult)
291 valueMap[valueID] = op->getResult(idx: 0);
292
293 if (op->hasTrait<OpTrait::IsTerminator>())
294 clearDebugLine();
295
296 return success();
297}
298
299LogicalResult spirv::Deserializer::processUndef(ArrayRef<uint32_t> operands) {
300 if (operands.size() != 2) {
301 return emitError(loc: unknownLoc, message: "OpUndef instruction must have two operands");
302 }
303 auto type = getType(id: operands[0]);
304 if (!type) {
305 return emitError(loc: unknownLoc, message: "unknown type <id> with OpUndef instruction");
306 }
307 undefMap[operands[1]] = type;
308 return success();
309}
310
311LogicalResult spirv::Deserializer::processExtInst(ArrayRef<uint32_t> operands) {
312 if (operands.size() < 4) {
313 return emitError(loc: unknownLoc,
314 message: "OpExtInst must have at least 4 operands, result type "
315 "<id>, result <id>, set <id> and instruction opcode");
316 }
317 if (!extendedInstSets.count(Val: operands[2])) {
318 return emitError(loc: unknownLoc, message: "undefined set <id> in OpExtInst");
319 }
320 SmallVector<uint32_t, 4> slicedOperands;
321 slicedOperands.append(in_start: operands.begin(), in_end: std::next(x: operands.begin(), n: 2));
322 slicedOperands.append(in_start: std::next(x: operands.begin(), n: 4), in_end: operands.end());
323 return dispatchToExtensionSetAutogenDeserialization(
324 extensionSetName: extendedInstSets[operands[2]], instructionID: operands[3], words: slicedOperands);
325}
326
327namespace mlir {
328namespace spirv {
329
330template <>
331LogicalResult
332Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) {
333 unsigned wordIndex = 0;
334 if (wordIndex >= words.size()) {
335 return emitError(unknownLoc,
336 "missing Execution Model specification in OpEntryPoint");
337 }
338 auto execModel = spirv::ExecutionModelAttr::get(
339 context, static_cast<spirv::ExecutionModel>(words[wordIndex++]));
340 if (wordIndex >= words.size()) {
341 return emitError(unknownLoc, "missing <id> in OpEntryPoint");
342 }
343 // Get the function <id>
344 auto fnID = words[wordIndex++];
345 // Get the function name
346 auto fnName = decodeStringLiteral(words, wordIndex);
347 // Verify that the function <id> matches the fnName
348 auto parsedFunc = getFunction(fnID);
349 if (!parsedFunc) {
350 return emitError(unknownLoc, "no function matching <id> ") << fnID;
351 }
352 if (parsedFunc.getName() != fnName) {
353 // The deserializer uses "spirv_fn_<id>" as the function name if the input
354 // SPIR-V blob does not contain a name for it. We should use a more clear
355 // indication for such case rather than relying on naming details.
356 if (!parsedFunc.getName().starts_with("spirv_fn_"))
357 return emitError(unknownLoc,
358 "function name mismatch between OpEntryPoint "
359 "and OpFunction with <id> ")
360 << fnID << ": " << fnName << " vs. " << parsedFunc.getName();
361 parsedFunc.setName(fnName);
362 }
363 SmallVector<Attribute, 4> interface;
364 while (wordIndex < words.size()) {
365 auto arg = getGlobalVariable(words[wordIndex]);
366 if (!arg) {
367 return emitError(unknownLoc, "undefined result <id> ")
368 << words[wordIndex] << " while decoding OpEntryPoint";
369 }
370 interface.push_back(SymbolRefAttr::get(arg.getOperation()));
371 wordIndex++;
372 }
373 opBuilder.create<spirv::EntryPointOp>(
374 unknownLoc, execModel, SymbolRefAttr::get(opBuilder.getContext(), fnName),
375 opBuilder.getArrayAttr(interface));
376 return success();
377}
378
379template <>
380LogicalResult
381Deserializer::processOp<spirv::ExecutionModeOp>(ArrayRef<uint32_t> words) {
382 unsigned wordIndex = 0;
383 if (wordIndex >= words.size()) {
384 return emitError(unknownLoc,
385 "missing function result <id> in OpExecutionMode");
386 }
387 // Get the function <id> to get the name of the function
388 auto fnID = words[wordIndex++];
389 auto fn = getFunction(fnID);
390 if (!fn) {
391 return emitError(unknownLoc, "no function matching <id> ") << fnID;
392 }
393 // Get the Execution mode
394 if (wordIndex >= words.size()) {
395 return emitError(unknownLoc, "missing Execution Mode in OpExecutionMode");
396 }
397 auto execMode = spirv::ExecutionModeAttr::get(
398 context, static_cast<spirv::ExecutionMode>(words[wordIndex++]));
399
400 // Get the values
401 SmallVector<Attribute, 4> attrListElems;
402 while (wordIndex < words.size()) {
403 attrListElems.push_back(opBuilder.getI32IntegerAttr(words[wordIndex++]));
404 }
405 auto values = opBuilder.getArrayAttr(attrListElems);
406 opBuilder.create<spirv::ExecutionModeOp>(
407 unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), fn.getName()),
408 execMode, values);
409 return success();
410}
411
412template <>
413LogicalResult
414Deserializer::processOp<spirv::FunctionCallOp>(ArrayRef<uint32_t> operands) {
415 if (operands.size() < 3) {
416 return emitError(unknownLoc,
417 "OpFunctionCall must have at least 3 operands");
418 }
419
420 Type resultType = getType(operands[0]);
421 if (!resultType) {
422 return emitError(unknownLoc, "undefined result type from <id> ")
423 << operands[0];
424 }
425
426 // Use null type to mean no result type.
427 if (isVoidType(resultType))
428 resultType = nullptr;
429
430 auto resultID = operands[1];
431 auto functionID = operands[2];
432
433 auto functionName = getFunctionSymbol(functionID);
434
435 SmallVector<Value, 4> arguments;
436 for (auto operand : llvm::drop_begin(operands, 3)) {
437 auto value = getValue(operand);
438 if (!value) {
439 return emitError(unknownLoc, "unknown <id> ")
440 << operand << " used by OpFunctionCall";
441 }
442 arguments.push_back(value);
443 }
444
445 auto opFunctionCall = opBuilder.create<spirv::FunctionCallOp>(
446 unknownLoc, resultType,
447 SymbolRefAttr::get(opBuilder.getContext(), functionName), arguments);
448
449 if (resultType)
450 valueMap[resultID] = opFunctionCall.getResult(0);
451 return success();
452}
453
454template <>
455LogicalResult
456Deserializer::processOp<spirv::CopyMemoryOp>(ArrayRef<uint32_t> words) {
457 SmallVector<Type, 1> resultTypes;
458 size_t wordIndex = 0;
459 SmallVector<Value, 4> operands;
460 SmallVector<NamedAttribute, 4> attributes;
461
462 if (wordIndex < words.size()) {
463 auto arg = getValue(words[wordIndex]);
464
465 if (!arg) {
466 return emitError(unknownLoc, "unknown result <id> : ")
467 << words[wordIndex];
468 }
469
470 operands.push_back(arg);
471 wordIndex++;
472 }
473
474 if (wordIndex < words.size()) {
475 auto arg = getValue(words[wordIndex]);
476
477 if (!arg) {
478 return emitError(unknownLoc, "unknown result <id> : ")
479 << words[wordIndex];
480 }
481
482 operands.push_back(arg);
483 wordIndex++;
484 }
485
486 bool isAlignedAttr = false;
487
488 if (wordIndex < words.size()) {
489 auto attrValue = words[wordIndex++];
490 auto attr = opBuilder.getAttr<spirv::MemoryAccessAttr>(
491 static_cast<spirv::MemoryAccess>(attrValue));
492 attributes.push_back(
493 opBuilder.getNamedAttr(attributeName<MemoryAccess>(), attr));
494 isAlignedAttr = (attrValue == 2);
495 }
496
497 if (isAlignedAttr && wordIndex < words.size()) {
498 attributes.push_back(opBuilder.getNamedAttr(
499 "alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
500 }
501
502 if (wordIndex < words.size()) {
503 auto attrValue = words[wordIndex++];
504 auto attr = opBuilder.getAttr<spirv::MemoryAccessAttr>(
505 static_cast<spirv::MemoryAccess>(attrValue));
506 attributes.push_back(opBuilder.getNamedAttr("source_memory_access", attr));
507 }
508
509 if (wordIndex < words.size()) {
510 attributes.push_back(opBuilder.getNamedAttr(
511 "source_alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
512 }
513
514 if (wordIndex != words.size()) {
515 return emitError(unknownLoc,
516 "found more operands than expected when deserializing "
517 "spirv::CopyMemoryOp, only ")
518 << wordIndex << " of " << words.size() << " processed";
519 }
520
521 Location loc = createFileLineColLoc(opBuilder);
522 opBuilder.create<spirv::CopyMemoryOp>(loc, resultTypes, operands, attributes);
523
524 return success();
525}
526
527template <>
528LogicalResult Deserializer::processOp<spirv::GenericCastToPtrExplicitOp>(
529 ArrayRef<uint32_t> words) {
530 if (words.size() != 4) {
531 return emitError(unknownLoc,
532 "expected 4 words in GenericCastToPtrExplicitOp"
533 " but got : ")
534 << words.size();
535 }
536 SmallVector<Type, 1> resultTypes;
537 SmallVector<Value, 4> operands;
538 uint32_t valueID = 0;
539 auto type = getType(words[0]);
540
541 if (!type)
542 return emitError(unknownLoc, "unknown type result <id> : ") << words[0];
543 resultTypes.push_back(type);
544
545 valueID = words[1];
546
547 auto arg = getValue(words[2]);
548 if (!arg)
549 return emitError(unknownLoc, "unknown result <id> : ") << words[2];
550 operands.push_back(arg);
551
552 Location loc = createFileLineColLoc(opBuilder);
553 Operation *op = opBuilder.create<spirv::GenericCastToPtrExplicitOp>(
554 loc, resultTypes, operands);
555 valueMap[valueID] = op->getResult(0);
556 return success();
557}
558
559// Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and
560// various Deserializer::processOp<...>() specializations.
561#define GET_DESERIALIZATION_FNS
562#include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
563
564} // namespace spirv
565} // namespace mlir
566

source code of mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp