1//===- DeserializeOps.cpp - MLIR SPIR-V Deserialization (Ops) -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the Deserializer methods for SPIR-V binary instructions.
10//
11//===----------------------------------------------------------------------===//
12
13#include "Deserializer.h"
14
15#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
16#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17#include "mlir/IR/Builders.h"
18#include "mlir/IR/Location.h"
19#include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
20#include "llvm/ADT/STLExtras.h"
21#include "llvm/ADT/SmallVector.h"
22#include "llvm/Support/Debug.h"
23#include <optional>
24
25using namespace mlir;
26
27#define DEBUG_TYPE "spirv-deserialization"
28
29//===----------------------------------------------------------------------===//
30// Utility Functions
31//===----------------------------------------------------------------------===//
32
33/// Extracts the opcode from the given first word of a SPIR-V instruction.
34static inline spirv::Opcode extractOpcode(uint32_t word) {
35 return static_cast<spirv::Opcode>(word & 0xffff);
36}
37
38//===----------------------------------------------------------------------===//
39// Instruction
40//===----------------------------------------------------------------------===//
41
42Value spirv::Deserializer::getValue(uint32_t id) {
43 if (auto constInfo = getConstant(id)) {
44 // Materialize a `spirv.Constant` op at every use site.
45 return opBuilder.create<spirv::ConstantOp>(location: unknownLoc, args&: constInfo->second,
46 args&: constInfo->first);
47 }
48 if (std::optional<std::pair<Attribute, Type>> constCompositeReplicateInfo =
49 getConstantCompositeReplicate(id)) {
50 return opBuilder.create<spirv::EXTConstantCompositeReplicateOp>(
51 location: unknownLoc, args&: constCompositeReplicateInfo->second,
52 args&: constCompositeReplicateInfo->first);
53 }
54 if (auto varOp = getGlobalVariable(id)) {
55 auto addressOfOp = opBuilder.create<spirv::AddressOfOp>(
56 location: unknownLoc, args: varOp.getType(), args: SymbolRefAttr::get(symbol: varOp.getOperation()));
57 return addressOfOp.getPointer();
58 }
59 if (auto constOp = getSpecConstant(id)) {
60 auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
61 location: unknownLoc, args: constOp.getDefaultValue().getType(),
62 args: SymbolRefAttr::get(symbol: constOp.getOperation()));
63 return referenceOfOp.getReference();
64 }
65 if (SpecConstantCompositeOp specConstCompositeOp =
66 getSpecConstantComposite(id)) {
67 auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
68 location: unknownLoc, args: specConstCompositeOp.getType(),
69 args: SymbolRefAttr::get(symbol: specConstCompositeOp.getOperation()));
70 return referenceOfOp.getReference();
71 }
72 if (auto specConstCompositeReplicateOp =
73 getSpecConstantCompositeReplicate(id)) {
74 auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
75 location: unknownLoc, args: specConstCompositeReplicateOp.getType(),
76 args: SymbolRefAttr::get(symbol: specConstCompositeReplicateOp.getOperation()));
77 return referenceOfOp.getReference();
78 }
79 if (auto specConstOperationInfo = getSpecConstantOperation(id)) {
80 return materializeSpecConstantOperation(
81 resultID: id, enclosedOpcode: specConstOperationInfo->enclodesOpcode,
82 resultTypeID: specConstOperationInfo->resultTypeID,
83 enclosedOpOperands: specConstOperationInfo->enclosedOpOperands);
84 }
85 if (auto undef = getUndefType(id)) {
86 return opBuilder.create<spirv::UndefOp>(location: unknownLoc, args&: undef);
87 }
88 return valueMap.lookup(Val: id);
89}
90
91LogicalResult spirv::Deserializer::sliceInstruction(
92 spirv::Opcode &opcode, ArrayRef<uint32_t> &operands,
93 std::optional<spirv::Opcode> expectedOpcode) {
94 auto binarySize = binary.size();
95 if (curOffset >= binarySize) {
96 return emitError(loc: unknownLoc, message: "expected ")
97 << (expectedOpcode ? spirv::stringifyOpcode(*expectedOpcode)
98 : "more")
99 << " instruction";
100 }
101
102 // For each instruction, get its word count from the first word to slice it
103 // from the stream properly, and then dispatch to the instruction handler.
104
105 uint32_t wordCount = binary[curOffset] >> 16;
106
107 if (wordCount == 0)
108 return emitError(loc: unknownLoc, message: "word count cannot be zero");
109
110 uint32_t nextOffset = curOffset + wordCount;
111 if (nextOffset > binarySize)
112 return emitError(loc: unknownLoc, message: "insufficient words for the last instruction");
113
114 opcode = extractOpcode(word: binary[curOffset]);
115 operands = binary.slice(N: curOffset + 1, M: wordCount - 1);
116 curOffset = nextOffset;
117 return success();
118}
119
120LogicalResult spirv::Deserializer::processInstruction(
121 spirv::Opcode opcode, ArrayRef<uint32_t> operands, bool deferInstructions) {
122 LLVM_DEBUG(logger.startLine() << "[inst] processing instruction "
123 << spirv::stringifyOpcode(opcode) << "\n");
124
125 // First dispatch all the instructions whose opcode does not correspond to
126 // those that have a direct mirror in the SPIR-V dialect
127 switch (opcode) {
128 case spirv::Opcode::OpCapability:
129 return processCapability(operands);
130 case spirv::Opcode::OpExtension:
131 return processExtension(words: operands);
132 case spirv::Opcode::OpExtInst:
133 return processExtInst(operands);
134 case spirv::Opcode::OpExtInstImport:
135 return processExtInstImport(words: operands);
136 case spirv::Opcode::OpMemberName:
137 return processMemberName(words: operands);
138 case spirv::Opcode::OpMemoryModel:
139 return processMemoryModel(operands);
140 case spirv::Opcode::OpEntryPoint:
141 case spirv::Opcode::OpExecutionMode:
142 if (deferInstructions) {
143 deferredInstructions.emplace_back(Args&: opcode, Args&: operands);
144 return success();
145 }
146 break;
147 case spirv::Opcode::OpVariable:
148 if (isa<spirv::ModuleOp>(Val: opBuilder.getBlock()->getParentOp())) {
149 return processGlobalVariable(operands);
150 }
151 break;
152 case spirv::Opcode::OpLine:
153 return processDebugLine(operands);
154 case spirv::Opcode::OpNoLine:
155 clearDebugLine();
156 return success();
157 case spirv::Opcode::OpName:
158 return processName(operands);
159 case spirv::Opcode::OpString:
160 return processDebugString(operands);
161 case spirv::Opcode::OpModuleProcessed:
162 case spirv::Opcode::OpSource:
163 case spirv::Opcode::OpSourceContinued:
164 case spirv::Opcode::OpSourceExtension:
165 // TODO: This is debug information embedded in the binary which should be
166 // translated into the spirv.module.
167 return success();
168 case spirv::Opcode::OpTypeVoid:
169 case spirv::Opcode::OpTypeBool:
170 case spirv::Opcode::OpTypeInt:
171 case spirv::Opcode::OpTypeFloat:
172 case spirv::Opcode::OpTypeVector:
173 case spirv::Opcode::OpTypeMatrix:
174 case spirv::Opcode::OpTypeArray:
175 case spirv::Opcode::OpTypeFunction:
176 case spirv::Opcode::OpTypeImage:
177 case spirv::Opcode::OpTypeSampledImage:
178 case spirv::Opcode::OpTypeRuntimeArray:
179 case spirv::Opcode::OpTypeStruct:
180 case spirv::Opcode::OpTypePointer:
181 case spirv::Opcode::OpTypeTensorARM:
182 case spirv::Opcode::OpTypeCooperativeMatrixKHR:
183 return processType(opcode, operands);
184 case spirv::Opcode::OpTypeForwardPointer:
185 return processTypeForwardPointer(operands);
186 case spirv::Opcode::OpConstant:
187 return processConstant(operands, /*isSpec=*/false);
188 case spirv::Opcode::OpSpecConstant:
189 return processConstant(operands, /*isSpec=*/true);
190 case spirv::Opcode::OpConstantComposite:
191 return processConstantComposite(operands);
192 case spirv::Opcode::OpConstantCompositeReplicateEXT:
193 return processConstantCompositeReplicateEXT(operands);
194 case spirv::Opcode::OpSpecConstantComposite:
195 return processSpecConstantComposite(operands);
196 case spirv::Opcode::OpSpecConstantCompositeReplicateEXT:
197 return processSpecConstantCompositeReplicateEXT(operands);
198 case spirv::Opcode::OpSpecConstantOp:
199 return processSpecConstantOperation(operands);
200 case spirv::Opcode::OpConstantTrue:
201 return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/false);
202 case spirv::Opcode::OpSpecConstantTrue:
203 return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/true);
204 case spirv::Opcode::OpConstantFalse:
205 return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/false);
206 case spirv::Opcode::OpSpecConstantFalse:
207 return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/true);
208 case spirv::Opcode::OpConstantNull:
209 return processConstantNull(operands);
210 case spirv::Opcode::OpDecorate:
211 return processDecoration(words: operands);
212 case spirv::Opcode::OpMemberDecorate:
213 return processMemberDecoration(words: operands);
214 case spirv::Opcode::OpFunction:
215 return processFunction(operands);
216 case spirv::Opcode::OpLabel:
217 return processLabel(operands);
218 case spirv::Opcode::OpBranch:
219 return processBranch(operands);
220 case spirv::Opcode::OpBranchConditional:
221 return processBranchConditional(operands);
222 case spirv::Opcode::OpSelectionMerge:
223 return processSelectionMerge(operands);
224 case spirv::Opcode::OpLoopMerge:
225 return processLoopMerge(operands);
226 case spirv::Opcode::OpPhi:
227 return processPhi(operands);
228 case spirv::Opcode::OpUndef:
229 return processUndef(operands);
230 default:
231 break;
232 }
233 return dispatchToAutogenDeserialization(opcode, words: operands);
234}
235
236LogicalResult spirv::Deserializer::processOpWithoutGrammarAttr(
237 ArrayRef<uint32_t> words, StringRef opName, bool hasResult,
238 unsigned numOperands) {
239 SmallVector<Type, 1> resultTypes;
240 uint32_t valueID = 0;
241
242 size_t wordIndex = 0;
243 if (hasResult) {
244 if (wordIndex >= words.size())
245 return emitError(loc: unknownLoc,
246 message: "expected result type <id> while deserializing for ")
247 << opName;
248
249 // Decode the type <id>
250 auto type = getType(id: words[wordIndex]);
251 if (!type)
252 return emitError(loc: unknownLoc, message: "unknown type result <id>: ")
253 << words[wordIndex];
254 resultTypes.push_back(Elt: type);
255 ++wordIndex;
256
257 // Decode the result <id>
258 if (wordIndex >= words.size())
259 return emitError(loc: unknownLoc,
260 message: "expected result <id> while deserializing for ")
261 << opName;
262 valueID = words[wordIndex];
263 ++wordIndex;
264 }
265
266 SmallVector<Value, 4> operands;
267 SmallVector<NamedAttribute, 4> attributes;
268
269 // Decode operands
270 size_t operandIndex = 0;
271 for (; operandIndex < numOperands && wordIndex < words.size();
272 ++operandIndex, ++wordIndex) {
273 auto arg = getValue(id: words[wordIndex]);
274 if (!arg)
275 return emitError(loc: unknownLoc, message: "unknown result <id>: ") << words[wordIndex];
276 operands.push_back(Elt: arg);
277 }
278 if (operandIndex != numOperands) {
279 return emitError(
280 loc: unknownLoc,
281 message: "found less operands than expected when deserializing for ")
282 << opName << "; only " << operandIndex << " of " << numOperands
283 << " processed";
284 }
285 if (wordIndex != words.size()) {
286 return emitError(
287 loc: unknownLoc,
288 message: "found more operands than expected when deserializing for ")
289 << opName << "; only " << wordIndex << " of " << words.size()
290 << " processed";
291 }
292
293 // Attach attributes from decorations
294 if (decorations.count(Val: valueID)) {
295 auto attrs = decorations[valueID].getAttrs();
296 attributes.append(in_start: attrs.begin(), in_end: attrs.end());
297 }
298
299 // Create the op and update bookkeeping maps
300 Location loc = createFileLineColLoc(opBuilder);
301 OperationState opState(loc, opName);
302 opState.addOperands(newOperands: operands);
303 if (hasResult)
304 opState.addTypes(newTypes: resultTypes);
305 opState.addAttributes(newAttributes: attributes);
306 Operation *op = opBuilder.create(state: opState);
307 if (hasResult)
308 valueMap[valueID] = op->getResult(idx: 0);
309
310 if (op->hasTrait<OpTrait::IsTerminator>())
311 clearDebugLine();
312
313 return success();
314}
315
316LogicalResult spirv::Deserializer::processUndef(ArrayRef<uint32_t> operands) {
317 if (operands.size() != 2) {
318 return emitError(loc: unknownLoc, message: "OpUndef instruction must have two operands");
319 }
320 auto type = getType(id: operands[0]);
321 if (!type) {
322 return emitError(loc: unknownLoc, message: "unknown type <id> with OpUndef instruction");
323 }
324 undefMap[operands[1]] = type;
325 return success();
326}
327
328LogicalResult spirv::Deserializer::processExtInst(ArrayRef<uint32_t> operands) {
329 if (operands.size() < 4) {
330 return emitError(loc: unknownLoc,
331 message: "OpExtInst must have at least 4 operands, result type "
332 "<id>, result <id>, set <id> and instruction opcode");
333 }
334 if (!extendedInstSets.count(Val: operands[2])) {
335 return emitError(loc: unknownLoc, message: "undefined set <id> in OpExtInst");
336 }
337 SmallVector<uint32_t, 4> slicedOperands;
338 slicedOperands.append(in_start: operands.begin(), in_end: std::next(x: operands.begin(), n: 2));
339 slicedOperands.append(in_start: std::next(x: operands.begin(), n: 4), in_end: operands.end());
340 return dispatchToExtensionSetAutogenDeserialization(
341 extensionSetName: extendedInstSets[operands[2]], instructionID: operands[3], words: slicedOperands);
342}
343
344namespace mlir {
345namespace spirv {
346
347template <>
348LogicalResult
349Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) {
350 unsigned wordIndex = 0;
351 if (wordIndex >= words.size()) {
352 return emitError(loc: unknownLoc,
353 message: "missing Execution Model specification in OpEntryPoint");
354 }
355 auto execModel = spirv::ExecutionModelAttr::get(
356 context, value: static_cast<spirv::ExecutionModel>(words[wordIndex++]));
357 if (wordIndex >= words.size()) {
358 return emitError(loc: unknownLoc, message: "missing <id> in OpEntryPoint");
359 }
360 // Get the function <id>
361 auto fnID = words[wordIndex++];
362 // Get the function name
363 auto fnName = decodeStringLiteral(words, wordIndex);
364 // Verify that the function <id> matches the fnName
365 auto parsedFunc = getFunction(id: fnID);
366 if (!parsedFunc) {
367 return emitError(loc: unknownLoc, message: "no function matching <id> ") << fnID;
368 }
369 if (parsedFunc.getName() != fnName) {
370 // The deserializer uses "spirv_fn_<id>" as the function name if the input
371 // SPIR-V blob does not contain a name for it. We should use a more clear
372 // indication for such case rather than relying on naming details.
373 if (!parsedFunc.getName().starts_with(Prefix: "spirv_fn_"))
374 return emitError(loc: unknownLoc,
375 message: "function name mismatch between OpEntryPoint "
376 "and OpFunction with <id> ")
377 << fnID << ": " << fnName << " vs. " << parsedFunc.getName();
378 parsedFunc.setName(fnName);
379 }
380 SmallVector<Attribute, 4> interface;
381 while (wordIndex < words.size()) {
382 auto arg = getGlobalVariable(id: words[wordIndex]);
383 if (!arg) {
384 return emitError(loc: unknownLoc, message: "undefined result <id> ")
385 << words[wordIndex] << " while decoding OpEntryPoint";
386 }
387 interface.push_back(Elt: SymbolRefAttr::get(symbol: arg.getOperation()));
388 wordIndex++;
389 }
390 opBuilder.create<spirv::EntryPointOp>(
391 location: unknownLoc, args&: execModel, args: SymbolRefAttr::get(ctx: opBuilder.getContext(), value: fnName),
392 args: opBuilder.getArrayAttr(value: interface));
393 return success();
394}
395
396template <>
397LogicalResult
398Deserializer::processOp<spirv::ExecutionModeOp>(ArrayRef<uint32_t> words) {
399 unsigned wordIndex = 0;
400 if (wordIndex >= words.size()) {
401 return emitError(loc: unknownLoc,
402 message: "missing function result <id> in OpExecutionMode");
403 }
404 // Get the function <id> to get the name of the function
405 auto fnID = words[wordIndex++];
406 auto fn = getFunction(id: fnID);
407 if (!fn) {
408 return emitError(loc: unknownLoc, message: "no function matching <id> ") << fnID;
409 }
410 // Get the Execution mode
411 if (wordIndex >= words.size()) {
412 return emitError(loc: unknownLoc, message: "missing Execution Mode in OpExecutionMode");
413 }
414 auto execMode = spirv::ExecutionModeAttr::get(
415 context, value: static_cast<spirv::ExecutionMode>(words[wordIndex++]));
416
417 // Get the values
418 SmallVector<Attribute, 4> attrListElems;
419 while (wordIndex < words.size()) {
420 attrListElems.push_back(Elt: opBuilder.getI32IntegerAttr(value: words[wordIndex++]));
421 }
422 auto values = opBuilder.getArrayAttr(value: attrListElems);
423 opBuilder.create<spirv::ExecutionModeOp>(
424 location: unknownLoc, args: SymbolRefAttr::get(ctx: opBuilder.getContext(), value: fn.getName()),
425 args&: execMode, args&: values);
426 return success();
427}
428
429template <>
430LogicalResult
431Deserializer::processOp<spirv::FunctionCallOp>(ArrayRef<uint32_t> operands) {
432 if (operands.size() < 3) {
433 return emitError(loc: unknownLoc,
434 message: "OpFunctionCall must have at least 3 operands");
435 }
436
437 Type resultType = getType(id: operands[0]);
438 if (!resultType) {
439 return emitError(loc: unknownLoc, message: "undefined result type from <id> ")
440 << operands[0];
441 }
442
443 // Use null type to mean no result type.
444 if (isVoidType(type: resultType))
445 resultType = nullptr;
446
447 auto resultID = operands[1];
448 auto functionID = operands[2];
449
450 auto functionName = getFunctionSymbol(id: functionID);
451
452 SmallVector<Value, 4> arguments;
453 for (auto operand : llvm::drop_begin(RangeOrContainer&: operands, N: 3)) {
454 auto value = getValue(id: operand);
455 if (!value) {
456 return emitError(loc: unknownLoc, message: "unknown <id> ")
457 << operand << " used by OpFunctionCall";
458 }
459 arguments.push_back(Elt: value);
460 }
461
462 auto opFunctionCall = opBuilder.create<spirv::FunctionCallOp>(
463 location: unknownLoc, args&: resultType,
464 args: SymbolRefAttr::get(ctx: opBuilder.getContext(), value: functionName), args&: arguments);
465
466 if (resultType)
467 valueMap[resultID] = opFunctionCall.getResult(i: 0);
468 return success();
469}
470
471template <>
472LogicalResult
473Deserializer::processOp<spirv::CopyMemoryOp>(ArrayRef<uint32_t> words) {
474 SmallVector<Type, 1> resultTypes;
475 size_t wordIndex = 0;
476 SmallVector<Value, 4> operands;
477 SmallVector<NamedAttribute, 4> attributes;
478
479 if (wordIndex < words.size()) {
480 auto arg = getValue(id: words[wordIndex]);
481
482 if (!arg) {
483 return emitError(loc: unknownLoc, message: "unknown result <id> : ")
484 << words[wordIndex];
485 }
486
487 operands.push_back(Elt: arg);
488 wordIndex++;
489 }
490
491 if (wordIndex < words.size()) {
492 auto arg = getValue(id: words[wordIndex]);
493
494 if (!arg) {
495 return emitError(loc: unknownLoc, message: "unknown result <id> : ")
496 << words[wordIndex];
497 }
498
499 operands.push_back(Elt: arg);
500 wordIndex++;
501 }
502
503 bool isAlignedAttr = false;
504
505 if (wordIndex < words.size()) {
506 auto attrValue = words[wordIndex++];
507 auto attr = opBuilder.getAttr<spirv::MemoryAccessAttr>(
508 args: static_cast<spirv::MemoryAccess>(attrValue));
509 attributes.push_back(
510 Elt: opBuilder.getNamedAttr(name: attributeName<MemoryAccess>(), val: attr));
511 isAlignedAttr = (attrValue == 2);
512 }
513
514 if (isAlignedAttr && wordIndex < words.size()) {
515 attributes.push_back(Elt: opBuilder.getNamedAttr(
516 name: "alignment", val: opBuilder.getI32IntegerAttr(value: words[wordIndex++])));
517 }
518
519 if (wordIndex < words.size()) {
520 auto attrValue = words[wordIndex++];
521 auto attr = opBuilder.getAttr<spirv::MemoryAccessAttr>(
522 args: static_cast<spirv::MemoryAccess>(attrValue));
523 attributes.push_back(Elt: opBuilder.getNamedAttr(name: "source_memory_access", val: attr));
524 }
525
526 if (wordIndex < words.size()) {
527 attributes.push_back(Elt: opBuilder.getNamedAttr(
528 name: "source_alignment", val: opBuilder.getI32IntegerAttr(value: words[wordIndex++])));
529 }
530
531 if (wordIndex != words.size()) {
532 return emitError(loc: unknownLoc,
533 message: "found more operands than expected when deserializing "
534 "spirv::CopyMemoryOp, only ")
535 << wordIndex << " of " << words.size() << " processed";
536 }
537
538 Location loc = createFileLineColLoc(opBuilder);
539 opBuilder.create<spirv::CopyMemoryOp>(location: loc, args&: resultTypes, args&: operands, args&: attributes);
540
541 return success();
542}
543
544template <>
545LogicalResult Deserializer::processOp<spirv::GenericCastToPtrExplicitOp>(
546 ArrayRef<uint32_t> words) {
547 if (words.size() != 4) {
548 return emitError(loc: unknownLoc,
549 message: "expected 4 words in GenericCastToPtrExplicitOp"
550 " but got : ")
551 << words.size();
552 }
553 SmallVector<Type, 1> resultTypes;
554 SmallVector<Value, 4> operands;
555 uint32_t valueID = 0;
556 auto type = getType(id: words[0]);
557
558 if (!type)
559 return emitError(loc: unknownLoc, message: "unknown type result <id> : ") << words[0];
560 resultTypes.push_back(Elt: type);
561
562 valueID = words[1];
563
564 auto arg = getValue(id: words[2]);
565 if (!arg)
566 return emitError(loc: unknownLoc, message: "unknown result <id> : ") << words[2];
567 operands.push_back(Elt: arg);
568
569 Location loc = createFileLineColLoc(opBuilder);
570 Operation *op = opBuilder.create<spirv::GenericCastToPtrExplicitOp>(
571 location: loc, args&: resultTypes, args&: operands);
572 valueMap[valueID] = op->getResult(idx: 0);
573 return success();
574}
575
576// Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and
577// various Deserializer::processOp<...>() specializations.
578#define GET_DESERIALIZATION_FNS
579#include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
580
581} // namespace spirv
582} // namespace mlir
583

source code of mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp