1//===- Deserializer.cpp - MLIR SPIR-V Deserializer ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the SPIR-V binary to MLIR SPIR-V module deserializer.
10//
11//===----------------------------------------------------------------------===//
12
13#include "Deserializer.h"
14
15#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
16#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
17#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19#include "mlir/IR/Builders.h"
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/Location.h"
22#include "mlir/Support/LogicalResult.h"
23#include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
24#include "llvm/ADT/STLExtras.h"
25#include "llvm/ADT/Sequence.h"
26#include "llvm/ADT/SmallVector.h"
27#include "llvm/ADT/StringExtras.h"
28#include "llvm/ADT/bit.h"
29#include "llvm/Support/Debug.h"
30#include "llvm/Support/SaveAndRestore.h"
31#include "llvm/Support/raw_ostream.h"
32#include <optional>
33
34using namespace mlir;
35
36#define DEBUG_TYPE "spirv-deserialization"
37
38//===----------------------------------------------------------------------===//
39// Utility Functions
40//===----------------------------------------------------------------------===//
41
42/// Returns true if the given `block` is a function entry block.
43static inline bool isFnEntryBlock(Block *block) {
44 return block->isEntryBlock() &&
45 isa_and_nonnull<spirv::FuncOp>(block->getParentOp());
46}
47
48//===----------------------------------------------------------------------===//
49// Deserializer Method Definitions
50//===----------------------------------------------------------------------===//
51
52spirv::Deserializer::Deserializer(ArrayRef<uint32_t> binary,
53 MLIRContext *context)
54 : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)),
55 module(createModuleOp()), opBuilder(module->getRegion())
56#ifndef NDEBUG
57 ,
58 logger(llvm::dbgs())
59#endif
60{
61}
62
63LogicalResult spirv::Deserializer::deserialize() {
64 LLVM_DEBUG({
65 logger.resetIndent();
66 logger.startLine()
67 << "//+++---------- start deserialization ----------+++//\n";
68 });
69
70 if (failed(result: processHeader()))
71 return failure();
72
73 spirv::Opcode opcode = spirv::Opcode::OpNop;
74 ArrayRef<uint32_t> operands;
75 auto binarySize = binary.size();
76 while (curOffset < binarySize) {
77 // Slice the next instruction out and populate `opcode` and `operands`.
78 // Internally this also updates `curOffset`.
79 if (failed(sliceInstruction(opcode, operands)))
80 return failure();
81
82 if (failed(processInstruction(opcode, operands)))
83 return failure();
84 }
85
86 assert(curOffset == binarySize &&
87 "deserializer should never index beyond the binary end");
88
89 for (auto &deferred : deferredInstructions) {
90 if (failed(processInstruction(deferred.first, deferred.second, false))) {
91 return failure();
92 }
93 }
94
95 attachVCETriple();
96
97 LLVM_DEBUG(logger.startLine()
98 << "//+++-------- completed deserialization --------+++//\n");
99 return success();
100}
101
102OwningOpRef<spirv::ModuleOp> spirv::Deserializer::collect() {
103 return std::move(module);
104}
105
106//===----------------------------------------------------------------------===//
107// Module structure
108//===----------------------------------------------------------------------===//
109
110OwningOpRef<spirv::ModuleOp> spirv::Deserializer::createModuleOp() {
111 OpBuilder builder(context);
112 OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
113 spirv::ModuleOp::build(builder, state);
114 return cast<spirv::ModuleOp>(Operation::create(state));
115}
116
117LogicalResult spirv::Deserializer::processHeader() {
118 if (binary.size() < spirv::kHeaderWordCount)
119 return emitError(loc: unknownLoc,
120 message: "SPIR-V binary module must have a 5-word header");
121
122 if (binary[0] != spirv::kMagicNumber)
123 return emitError(loc: unknownLoc, message: "incorrect magic number");
124
125 // Version number bytes: 0 | major number | minor number | 0
126 uint32_t majorVersion = (binary[1] << 8) >> 24;
127 uint32_t minorVersion = (binary[1] << 16) >> 24;
128 if (majorVersion == 1) {
129 switch (minorVersion) {
130#define MIN_VERSION_CASE(v) \
131 case v: \
132 version = spirv::Version::V_1_##v; \
133 break
134
135 MIN_VERSION_CASE(0);
136 MIN_VERSION_CASE(1);
137 MIN_VERSION_CASE(2);
138 MIN_VERSION_CASE(3);
139 MIN_VERSION_CASE(4);
140 MIN_VERSION_CASE(5);
141#undef MIN_VERSION_CASE
142 default:
143 return emitError(loc: unknownLoc, message: "unsupported SPIR-V minor version: ")
144 << minorVersion;
145 }
146 } else {
147 return emitError(loc: unknownLoc, message: "unsupported SPIR-V major version: ")
148 << majorVersion;
149 }
150
151 // TODO: generator number, bound, schema
152 curOffset = spirv::kHeaderWordCount;
153 return success();
154}
155
156LogicalResult
157spirv::Deserializer::processCapability(ArrayRef<uint32_t> operands) {
158 if (operands.size() != 1)
159 return emitError(loc: unknownLoc, message: "OpMemoryModel must have one parameter");
160
161 auto cap = spirv::symbolizeCapability(operands[0]);
162 if (!cap)
163 return emitError(loc: unknownLoc, message: "unknown capability: ") << operands[0];
164
165 capabilities.insert(*cap);
166 return success();
167}
168
169LogicalResult spirv::Deserializer::processExtension(ArrayRef<uint32_t> words) {
170 if (words.empty()) {
171 return emitError(
172 loc: unknownLoc,
173 message: "OpExtension must have a literal string for the extension name");
174 }
175
176 unsigned wordIndex = 0;
177 StringRef extName = decodeStringLiteral(words, wordIndex);
178 if (wordIndex != words.size())
179 return emitError(loc: unknownLoc,
180 message: "unexpected trailing words in OpExtension instruction");
181 auto ext = spirv::symbolizeExtension(extName);
182 if (!ext)
183 return emitError(loc: unknownLoc, message: "unknown extension: ") << extName;
184
185 extensions.insert(*ext);
186 return success();
187}
188
189LogicalResult
190spirv::Deserializer::processExtInstImport(ArrayRef<uint32_t> words) {
191 if (words.size() < 2) {
192 return emitError(loc: unknownLoc,
193 message: "OpExtInstImport must have a result <id> and a literal "
194 "string for the extended instruction set name");
195 }
196
197 unsigned wordIndex = 1;
198 extendedInstSets[words[0]] = decodeStringLiteral(words, wordIndex);
199 if (wordIndex != words.size()) {
200 return emitError(loc: unknownLoc,
201 message: "unexpected trailing words in OpExtInstImport");
202 }
203 return success();
204}
205
206void spirv::Deserializer::attachVCETriple() {
207 (*module)->setAttr(
208 spirv::ModuleOp::getVCETripleAttrName(),
209 spirv::VerCapExtAttr::get(version, capabilities.getArrayRef(),
210 extensions.getArrayRef(), context));
211}
212
213LogicalResult
214spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
215 if (operands.size() != 2)
216 return emitError(loc: unknownLoc, message: "OpMemoryModel must have two operands");
217
218 (*module)->setAttr(
219 module->getAddressingModelAttrName(),
220 opBuilder.getAttr<spirv::AddressingModelAttr>(
221 static_cast<spirv::AddressingModel>(operands.front())));
222
223 (*module)->setAttr(module->getMemoryModelAttrName(),
224 opBuilder.getAttr<spirv::MemoryModelAttr>(
225 static_cast<spirv::MemoryModel>(operands.back())));
226
227 return success();
228}
229
230LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
231 // TODO: This function should also be auto-generated. For now, since only a
232 // few decorations are processed/handled in a meaningful manner, going with a
233 // manual implementation.
234 if (words.size() < 2) {
235 return emitError(
236 loc: unknownLoc, message: "OpDecorate must have at least result <id> and Decoration");
237 }
238 auto decorationName =
239 stringifyDecoration(static_cast<spirv::Decoration>(words[1]));
240 if (decorationName.empty()) {
241 return emitError(loc: unknownLoc, message: "invalid Decoration code : ") << words[1];
242 }
243 auto symbol = getSymbolDecoration(decorationName: decorationName);
244 switch (static_cast<spirv::Decoration>(words[1])) {
245 case spirv::Decoration::FPFastMathMode:
246 if (words.size() != 3) {
247 return emitError(loc: unknownLoc, message: "OpDecorate with ")
248 << decorationName << " needs a single integer literal";
249 }
250 decorations[words[0]].set(
251 symbol, FPFastMathModeAttr::get(opBuilder.getContext(),
252 static_cast<FPFastMathMode>(words[2])));
253 break;
254 case spirv::Decoration::DescriptorSet:
255 case spirv::Decoration::Binding:
256 if (words.size() != 3) {
257 return emitError(loc: unknownLoc, message: "OpDecorate with ")
258 << decorationName << " needs a single integer literal";
259 }
260 decorations[words[0]].set(
261 symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
262 break;
263 case spirv::Decoration::BuiltIn:
264 if (words.size() != 3) {
265 return emitError(loc: unknownLoc, message: "OpDecorate with ")
266 << decorationName << " needs a single integer literal";
267 }
268 decorations[words[0]].set(
269 symbol, opBuilder.getStringAttr(
270 bytes: stringifyBuiltIn(static_cast<spirv::BuiltIn>(words[2]))));
271 break;
272 case spirv::Decoration::ArrayStride:
273 if (words.size() != 3) {
274 return emitError(loc: unknownLoc, message: "OpDecorate with ")
275 << decorationName << " needs a single integer literal";
276 }
277 typeDecorations[words[0]] = words[2];
278 break;
279 case spirv::Decoration::LinkageAttributes: {
280 if (words.size() < 4) {
281 return emitError(loc: unknownLoc, message: "OpDecorate with ")
282 << decorationName
283 << " needs at least 1 string and 1 integer literal";
284 }
285 // LinkageAttributes has two parameters ["linkageName", linkageType]
286 // e.g., OpDecorate %imported_func LinkageAttributes "outside.func" Import
287 // "linkageName" is a stringliteral encoded as uint32_t,
288 // hence the size of name is variable length which results in words.size()
289 // being variable length, words.size() = 3 + strlen(name)/4 + 1 or
290 // 3 + ceildiv(strlen(name), 4).
291 unsigned wordIndex = 2;
292 auto linkageName = spirv::decodeStringLiteral(words, wordIndex).str();
293 auto linkageTypeAttr = opBuilder.getAttr<::mlir::spirv::LinkageTypeAttr>(
294 static_cast<::mlir::spirv::LinkageType>(words[wordIndex++]));
295 auto linkageAttr = opBuilder.getAttr<::mlir::spirv::LinkageAttributesAttr>(
296 StringAttr::get(context, linkageName), linkageTypeAttr);
297 decorations[words[0]].set(symbol, llvm::dyn_cast<Attribute>(linkageAttr));
298 break;
299 }
300 case spirv::Decoration::Aliased:
301 case spirv::Decoration::AliasedPointer:
302 case spirv::Decoration::Block:
303 case spirv::Decoration::BufferBlock:
304 case spirv::Decoration::Flat:
305 case spirv::Decoration::NonReadable:
306 case spirv::Decoration::NonWritable:
307 case spirv::Decoration::NoPerspective:
308 case spirv::Decoration::NoSignedWrap:
309 case spirv::Decoration::NoUnsignedWrap:
310 case spirv::Decoration::RelaxedPrecision:
311 case spirv::Decoration::Restrict:
312 case spirv::Decoration::RestrictPointer:
313 case spirv::Decoration::NoContraction:
314 if (words.size() != 2) {
315 return emitError(loc: unknownLoc, message: "OpDecoration with ")
316 << decorationName << "needs a single target <id>";
317 }
318 // Block decoration does not affect spirv.struct type, but is still stored
319 // for verification.
320 // TODO: Update StructType to contain this information since
321 // it is needed for many validation rules.
322 decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
323 break;
324 case spirv::Decoration::Location:
325 case spirv::Decoration::SpecId:
326 if (words.size() != 3) {
327 return emitError(loc: unknownLoc, message: "OpDecoration with ")
328 << decorationName << "needs a single integer literal";
329 }
330 decorations[words[0]].set(
331 symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
332 break;
333 default:
334 return emitError(loc: unknownLoc, message: "unhandled Decoration : '") << decorationName;
335 }
336 return success();
337}
338
339LogicalResult
340spirv::Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) {
341 // The binary layout of OpMemberDecorate is different comparing to OpDecorate
342 if (words.size() < 3) {
343 return emitError(loc: unknownLoc,
344 message: "OpMemberDecorate must have at least 3 operands");
345 }
346
347 auto decoration = static_cast<spirv::Decoration>(words[2]);
348 if (decoration == spirv::Decoration::Offset && words.size() != 4) {
349 return emitError(loc: unknownLoc,
350 message: " missing offset specification in OpMemberDecorate with "
351 "Offset decoration");
352 }
353 ArrayRef<uint32_t> decorationOperands;
354 if (words.size() > 3) {
355 decorationOperands = words.slice(N: 3);
356 }
357 memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands;
358 return success();
359}
360
361LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) {
362 if (words.size() < 3) {
363 return emitError(loc: unknownLoc, message: "OpMemberName must have at least 3 operands");
364 }
365 unsigned wordIndex = 2;
366 auto name = decodeStringLiteral(words, wordIndex);
367 if (wordIndex != words.size()) {
368 return emitError(loc: unknownLoc,
369 message: "unexpected trailing words in OpMemberName instruction");
370 }
371 memberNameMap[words[0]][words[1]] = name;
372 return success();
373}
374
375LogicalResult spirv::Deserializer::setFunctionArgAttrs(
376 uint32_t argID, SmallVectorImpl<Attribute> &argAttrs, size_t argIndex) {
377 if (!decorations.contains(Val: argID)) {
378 argAttrs[argIndex] = DictionaryAttr::get(context, {});
379 return success();
380 }
381
382 spirv::DecorationAttr foundDecorationAttr;
383 for (NamedAttribute decAttr : decorations[argID]) {
384 for (auto decoration :
385 {spirv::Decoration::Aliased, spirv::Decoration::Restrict,
386 spirv::Decoration::AliasedPointer,
387 spirv::Decoration::RestrictPointer}) {
388
389 if (decAttr.getName() !=
390 getSymbolDecoration(stringifyDecoration(decoration)))
391 continue;
392
393 if (foundDecorationAttr)
394 return emitError(unknownLoc,
395 "more than one Aliased/Restrict decorations for "
396 "function argument with result <id> ")
397 << argID;
398
399 foundDecorationAttr = spirv::DecorationAttr::get(context, decoration);
400 break;
401 }
402 }
403
404 if (!foundDecorationAttr)
405 return emitError(loc: unknownLoc, message: "unimplemented decoration support for "
406 "function argument with result <id> ")
407 << argID;
408
409 NamedAttribute attr(StringAttr::get(context, spirv::DecorationAttr::name),
410 foundDecorationAttr);
411 argAttrs[argIndex] = DictionaryAttr::get(context, attr);
412 return success();
413}
414
415LogicalResult
416spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) {
417 if (curFunction) {
418 return emitError(loc: unknownLoc, message: "found function inside function");
419 }
420
421 // Get the result type
422 if (operands.size() != 4) {
423 return emitError(loc: unknownLoc, message: "OpFunction must have 4 parameters");
424 }
425 Type resultType = getType(id: operands[0]);
426 if (!resultType) {
427 return emitError(loc: unknownLoc, message: "undefined result type from <id> ")
428 << operands[0];
429 }
430
431 uint32_t fnID = operands[1];
432 if (funcMap.count(fnID)) {
433 return emitError(loc: unknownLoc, message: "duplicate function definition/declaration");
434 }
435
436 auto fnControl = spirv::symbolizeFunctionControl(operands[2]);
437 if (!fnControl) {
438 return emitError(loc: unknownLoc, message: "unknown Function Control: ") << operands[2];
439 }
440
441 Type fnType = getType(id: operands[3]);
442 if (!fnType || !isa<FunctionType>(Val: fnType)) {
443 return emitError(loc: unknownLoc, message: "unknown function type from <id> ")
444 << operands[3];
445 }
446 auto functionType = cast<FunctionType>(fnType);
447
448 if ((isVoidType(type: resultType) && functionType.getNumResults() != 0) ||
449 (functionType.getNumResults() == 1 &&
450 functionType.getResult(0) != resultType)) {
451 return emitError(loc: unknownLoc, message: "mismatch in function type ")
452 << functionType << " and return type " << resultType << " specified";
453 }
454
455 std::string fnName = getFunctionSymbol(id: fnID);
456 auto funcOp = opBuilder.create<spirv::FuncOp>(
457 unknownLoc, fnName, functionType, fnControl.value());
458 // Processing other function attributes.
459 if (decorations.count(Val: fnID)) {
460 for (auto attr : decorations[fnID].getAttrs()) {
461 funcOp->setAttr(attr.getName(), attr.getValue());
462 }
463 }
464 curFunction = funcMap[fnID] = funcOp;
465 auto *entryBlock = funcOp.addEntryBlock();
466 LLVM_DEBUG({
467 logger.startLine()
468 << "//===-------------------------------------------===//\n";
469 logger.startLine() << "[fn] name: " << fnName << "\n";
470 logger.startLine() << "[fn] type: " << fnType << "\n";
471 logger.startLine() << "[fn] ID: " << fnID << "\n";
472 logger.startLine() << "[fn] entry block: " << entryBlock << "\n";
473 logger.indent();
474 });
475
476 SmallVector<Attribute> argAttrs;
477 argAttrs.resize(functionType.getNumInputs());
478
479 // Parse the op argument instructions
480 if (functionType.getNumInputs()) {
481 for (size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
482 auto argType = functionType.getInput(i);
483 spirv::Opcode opcode = spirv::Opcode::OpNop;
484 ArrayRef<uint32_t> operands;
485 if (failed(sliceInstruction(opcode, operands,
486 spirv::Opcode::OpFunctionParameter))) {
487 return failure();
488 }
489 if (opcode != spirv::Opcode::OpFunctionParameter) {
490 return emitError(
491 loc: unknownLoc,
492 message: "missing OpFunctionParameter instruction for argument ")
493 << i;
494 }
495 if (operands.size() != 2) {
496 return emitError(
497 loc: unknownLoc,
498 message: "expected result type and result <id> for OpFunctionParameter");
499 }
500 auto argDefinedType = getType(id: operands[0]);
501 if (!argDefinedType || argDefinedType != argType) {
502 return emitError(loc: unknownLoc,
503 message: "mismatch in argument type between function type "
504 "definition ")
505 << functionType << " and argument type definition "
506 << argDefinedType << " at argument " << i;
507 }
508 if (getValue(id: operands[1])) {
509 return emitError(loc: unknownLoc, message: "duplicate definition of result <id> ")
510 << operands[1];
511 }
512 if (failed(result: setFunctionArgAttrs(argID: operands[1], argAttrs, argIndex: i))) {
513 return failure();
514 }
515
516 auto argValue = funcOp.getArgument(i);
517 valueMap[operands[1]] = argValue;
518 }
519 }
520
521 if (llvm::any_of(argAttrs, [](Attribute attr) {
522 auto argAttr = cast<DictionaryAttr>(attr);
523 return !argAttr.empty();
524 }))
525 funcOp.setArgAttrsAttr(ArrayAttr::get(context, argAttrs));
526
527 // entryBlock is needed to access the arguments, Once that is done, we can
528 // erase the block for functions with 'Import' LinkageAttributes, since these
529 // are essentially function declarations, so they have no body.
530 auto linkageAttr = funcOp.getLinkageAttributes();
531 auto hasImportLinkage =
532 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
533 spirv::LinkageType::Import);
534 if (hasImportLinkage)
535 funcOp.eraseBody();
536
537 // RAII guard to reset the insertion point to the module's region after
538 // deserializing the body of this function.
539 OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
540
541 spirv::Opcode opcode = spirv::Opcode::OpNop;
542 ArrayRef<uint32_t> instOperands;
543
544 // Special handling for the entry block. We need to make sure it starts with
545 // an OpLabel instruction. The entry block takes the same parameters as the
546 // function. All other blocks do not take any parameter. We have already
547 // created the entry block, here we need to register it to the correct label
548 // <id>.
549 if (failed(sliceInstruction(opcode, instOperands,
550 spirv::Opcode::OpFunctionEnd))) {
551 return failure();
552 }
553 if (opcode == spirv::Opcode::OpFunctionEnd) {
554 return processFunctionEnd(operands: instOperands);
555 }
556 if (opcode != spirv::Opcode::OpLabel) {
557 return emitError(loc: unknownLoc, message: "a basic block must start with OpLabel");
558 }
559 if (instOperands.size() != 1) {
560 return emitError(loc: unknownLoc, message: "OpLabel should only have result <id>");
561 }
562 blockMap[instOperands[0]] = entryBlock;
563 if (failed(result: processLabel(operands: instOperands))) {
564 return failure();
565 }
566
567 // Then process all the other instructions in the function until we hit
568 // OpFunctionEnd.
569 while (succeeded(sliceInstruction(opcode, instOperands,
570 spirv::Opcode::OpFunctionEnd)) &&
571 opcode != spirv::Opcode::OpFunctionEnd) {
572 if (failed(processInstruction(opcode, instOperands))) {
573 return failure();
574 }
575 }
576 if (opcode != spirv::Opcode::OpFunctionEnd) {
577 return failure();
578 }
579
580 return processFunctionEnd(operands: instOperands);
581}
582
583LogicalResult
584spirv::Deserializer::processFunctionEnd(ArrayRef<uint32_t> operands) {
585 // Process OpFunctionEnd.
586 if (!operands.empty()) {
587 return emitError(loc: unknownLoc, message: "unexpected operands for OpFunctionEnd");
588 }
589
590 // Wire up block arguments from OpPhi instructions.
591 // Put all structured control flow in spirv.mlir.selection/spirv.mlir.loop
592 // ops.
593 if (failed(result: wireUpBlockArgument()) || failed(result: structurizeControlFlow())) {
594 return failure();
595 }
596
597 curBlock = nullptr;
598 curFunction = std::nullopt;
599
600 LLVM_DEBUG({
601 logger.unindent();
602 logger.startLine()
603 << "//===-------------------------------------------===//\n";
604 });
605 return success();
606}
607
608std::optional<std::pair<Attribute, Type>>
609spirv::Deserializer::getConstant(uint32_t id) {
610 auto constIt = constantMap.find(Val: id);
611 if (constIt == constantMap.end())
612 return std::nullopt;
613 return constIt->getSecond();
614}
615
616std::optional<spirv::SpecConstOperationMaterializationInfo>
617spirv::Deserializer::getSpecConstantOperation(uint32_t id) {
618 auto constIt = specConstOperationMap.find(Val: id);
619 if (constIt == specConstOperationMap.end())
620 return std::nullopt;
621 return constIt->getSecond();
622}
623
624std::string spirv::Deserializer::getFunctionSymbol(uint32_t id) {
625 auto funcName = nameMap.lookup(Val: id).str();
626 if (funcName.empty()) {
627 funcName = "spirv_fn_" + std::to_string(val: id);
628 }
629 return funcName;
630}
631
632std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t id) {
633 auto constName = nameMap.lookup(Val: id).str();
634 if (constName.empty()) {
635 constName = "spirv_spec_const_" + std::to_string(val: id);
636 }
637 return constName;
638}
639
640spirv::SpecConstantOp
641spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID,
642 TypedAttr defaultValue) {
643 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(id: resultID));
644 auto op = opBuilder.create<spirv::SpecConstantOp>(unknownLoc, symName,
645 defaultValue);
646 if (decorations.count(Val: resultID)) {
647 for (auto attr : decorations[resultID].getAttrs())
648 op->setAttr(attr.getName(), attr.getValue());
649 }
650 specConstMap[resultID] = op;
651 return op;
652}
653
654LogicalResult
655spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
656 unsigned wordIndex = 0;
657 if (operands.size() < 3) {
658 return emitError(
659 loc: unknownLoc,
660 message: "OpVariable needs at least 3 operands, type, <id> and storage class");
661 }
662
663 // Result Type.
664 auto type = getType(id: operands[wordIndex]);
665 if (!type) {
666 return emitError(loc: unknownLoc, message: "unknown result type <id> : ")
667 << operands[wordIndex];
668 }
669 auto ptrType = dyn_cast<spirv::PointerType>(Val&: type);
670 if (!ptrType) {
671 return emitError(loc: unknownLoc,
672 message: "expected a result type <id> to be a spirv.ptr, found : ")
673 << type;
674 }
675 wordIndex++;
676
677 // Result <id>.
678 auto variableID = operands[wordIndex];
679 auto variableName = nameMap.lookup(Val: variableID).str();
680 if (variableName.empty()) {
681 variableName = "spirv_var_" + std::to_string(val: variableID);
682 }
683 wordIndex++;
684
685 // Storage class.
686 auto storageClass = static_cast<spirv::StorageClass>(operands[wordIndex]);
687 if (ptrType.getStorageClass() != storageClass) {
688 return emitError(loc: unknownLoc, message: "mismatch in storage class of pointer type ")
689 << type << " and that specified in OpVariable instruction : "
690 << stringifyStorageClass(storageClass);
691 }
692 wordIndex++;
693
694 // Initializer.
695 FlatSymbolRefAttr initializer = nullptr;
696
697 if (wordIndex < operands.size()) {
698 Operation *op = nullptr;
699
700 if (auto initOp = getGlobalVariable(operands[wordIndex]))
701 op = initOp;
702 else if (auto initOp = getSpecConstant(operands[wordIndex]))
703 op = initOp;
704 else if (auto initOp = getSpecConstantComposite(operands[wordIndex]))
705 op = initOp;
706 else
707 return emitError(loc: unknownLoc, message: "unknown <id> ")
708 << operands[wordIndex] << "used as initializer";
709
710 initializer = SymbolRefAttr::get(op);
711 wordIndex++;
712 }
713 if (wordIndex != operands.size()) {
714 return emitError(loc: unknownLoc,
715 message: "found more operands than expected when deserializing "
716 "OpVariable instruction, only ")
717 << wordIndex << " of " << operands.size() << " processed";
718 }
719 auto loc = createFileLineColLoc(opBuilder);
720 auto varOp = opBuilder.create<spirv::GlobalVariableOp>(
721 loc, TypeAttr::get(type), opBuilder.getStringAttr(variableName),
722 initializer);
723
724 // Decorations.
725 if (decorations.count(Val: variableID)) {
726 for (auto attr : decorations[variableID].getAttrs())
727 varOp->setAttr(attr.getName(), attr.getValue());
728 }
729 globalVariableMap[variableID] = varOp;
730 return success();
731}
732
733IntegerAttr spirv::Deserializer::getConstantInt(uint32_t id) {
734 auto constInfo = getConstant(id);
735 if (!constInfo) {
736 return nullptr;
737 }
738 return dyn_cast<IntegerAttr>(constInfo->first);
739}
740
741LogicalResult spirv::Deserializer::processName(ArrayRef<uint32_t> operands) {
742 if (operands.size() < 2) {
743 return emitError(loc: unknownLoc, message: "OpName needs at least 2 operands");
744 }
745 if (!nameMap.lookup(Val: operands[0]).empty()) {
746 return emitError(loc: unknownLoc, message: "duplicate name found for result <id> ")
747 << operands[0];
748 }
749 unsigned wordIndex = 1;
750 StringRef name = decodeStringLiteral(words: operands, wordIndex);
751 if (wordIndex != operands.size()) {
752 return emitError(loc: unknownLoc,
753 message: "unexpected trailing words in OpName instruction");
754 }
755 nameMap[operands[0]] = name;
756 return success();
757}
758
759//===----------------------------------------------------------------------===//
760// Type
761//===----------------------------------------------------------------------===//
762
763LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
764 ArrayRef<uint32_t> operands) {
765 if (operands.empty()) {
766 return emitError(unknownLoc, "type instruction with opcode ")
767 << spirv::stringifyOpcode(opcode) << " needs at least one <id>";
768 }
769
770 /// TODO: Types might be forward declared in some instructions and need to be
771 /// handled appropriately.
772 if (typeMap.count(Val: operands[0])) {
773 return emitError(loc: unknownLoc, message: "duplicate definition for result <id> ")
774 << operands[0];
775 }
776
777 switch (opcode) {
778 case spirv::Opcode::OpTypeVoid:
779 if (operands.size() != 1)
780 return emitError(loc: unknownLoc, message: "OpTypeVoid must have no parameters");
781 typeMap[operands[0]] = opBuilder.getNoneType();
782 break;
783 case spirv::Opcode::OpTypeBool:
784 if (operands.size() != 1)
785 return emitError(loc: unknownLoc, message: "OpTypeBool must have no parameters");
786 typeMap[operands[0]] = opBuilder.getI1Type();
787 break;
788 case spirv::Opcode::OpTypeInt: {
789 if (operands.size() != 3)
790 return emitError(
791 loc: unknownLoc, message: "OpTypeInt must have bitwidth and signedness parameters");
792
793 // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
794 // to preserve or validate.
795 // 0 indicates unsigned, or no signedness semantics
796 // 1 indicates signed semantics."
797 //
798 // So we cannot differentiate signless and unsigned integers; always use
799 // signless semantics for such cases.
800 auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed
801 : IntegerType::SignednessSemantics::Signless;
802 typeMap[operands[0]] = IntegerType::get(context, operands[1], sign);
803 } break;
804 case spirv::Opcode::OpTypeFloat: {
805 if (operands.size() != 2)
806 return emitError(loc: unknownLoc, message: "OpTypeFloat must have bitwidth parameter");
807
808 Type floatTy;
809 switch (operands[1]) {
810 case 16:
811 floatTy = opBuilder.getF16Type();
812 break;
813 case 32:
814 floatTy = opBuilder.getF32Type();
815 break;
816 case 64:
817 floatTy = opBuilder.getF64Type();
818 break;
819 default:
820 return emitError(loc: unknownLoc, message: "unsupported OpTypeFloat bitwidth: ")
821 << operands[1];
822 }
823 typeMap[operands[0]] = floatTy;
824 } break;
825 case spirv::Opcode::OpTypeVector: {
826 if (operands.size() != 3) {
827 return emitError(
828 loc: unknownLoc,
829 message: "OpTypeVector must have element type and count parameters");
830 }
831 Type elementTy = getType(id: operands[1]);
832 if (!elementTy) {
833 return emitError(loc: unknownLoc, message: "OpTypeVector references undefined <id> ")
834 << operands[1];
835 }
836 typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy);
837 } break;
838 case spirv::Opcode::OpTypePointer: {
839 return processOpTypePointer(operands);
840 } break;
841 case spirv::Opcode::OpTypeArray:
842 return processArrayType(operands);
843 case spirv::Opcode::OpTypeCooperativeMatrixKHR:
844 return processCooperativeMatrixTypeKHR(operands);
845 case spirv::Opcode::OpTypeFunction:
846 return processFunctionType(operands);
847 case spirv::Opcode::OpTypeJointMatrixINTEL:
848 return processJointMatrixType(operands);
849 case spirv::Opcode::OpTypeImage:
850 return processImageType(operands);
851 case spirv::Opcode::OpTypeSampledImage:
852 return processSampledImageType(operands);
853 case spirv::Opcode::OpTypeRuntimeArray:
854 return processRuntimeArrayType(operands);
855 case spirv::Opcode::OpTypeStruct:
856 return processStructType(operands);
857 case spirv::Opcode::OpTypeMatrix:
858 return processMatrixType(operands);
859 default:
860 return emitError(loc: unknownLoc, message: "unhandled type instruction");
861 }
862 return success();
863}
864
865LogicalResult
866spirv::Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) {
867 if (operands.size() != 3)
868 return emitError(loc: unknownLoc, message: "OpTypePointer must have two parameters");
869
870 auto pointeeType = getType(id: operands[2]);
871 if (!pointeeType)
872 return emitError(loc: unknownLoc, message: "unknown OpTypePointer pointee type <id> ")
873 << operands[2];
874
875 uint32_t typePointerID = operands[0];
876 auto storageClass = static_cast<spirv::StorageClass>(operands[1]);
877 typeMap[typePointerID] = spirv::PointerType::get(pointeeType, storageClass);
878
879 for (auto *deferredStructIt = std::begin(cont&: deferredStructTypesInfos);
880 deferredStructIt != std::end(cont&: deferredStructTypesInfos);) {
881 for (auto *unresolvedMemberIt =
882 std::begin(cont&: deferredStructIt->unresolvedMemberTypes);
883 unresolvedMemberIt !=
884 std::end(cont&: deferredStructIt->unresolvedMemberTypes);) {
885 if (unresolvedMemberIt->first == typePointerID) {
886 // The newly constructed pointer type can resolve one of the
887 // deferred struct type members; update the memberTypes list and
888 // clean the unresolvedMemberTypes list accordingly.
889 deferredStructIt->memberTypes[unresolvedMemberIt->second] =
890 typeMap[typePointerID];
891 unresolvedMemberIt =
892 deferredStructIt->unresolvedMemberTypes.erase(CI: unresolvedMemberIt);
893 } else {
894 ++unresolvedMemberIt;
895 }
896 }
897
898 if (deferredStructIt->unresolvedMemberTypes.empty()) {
899 // All deferred struct type members are now resolved, set the struct body.
900 auto structType = deferredStructIt->deferredStructType;
901
902 assert(structType && "expected a spirv::StructType");
903 assert(structType.isIdentified() && "expected an indentified struct");
904
905 if (failed(result: structType.trySetBody(
906 memberTypes: deferredStructIt->memberTypes, offsetInfo: deferredStructIt->offsetInfo,
907 memberDecorations: deferredStructIt->memberDecorationsInfo)))
908 return failure();
909
910 deferredStructIt = deferredStructTypesInfos.erase(CI: deferredStructIt);
911 } else {
912 ++deferredStructIt;
913 }
914 }
915
916 return success();
917}
918
919LogicalResult
920spirv::Deserializer::processArrayType(ArrayRef<uint32_t> operands) {
921 if (operands.size() != 3) {
922 return emitError(loc: unknownLoc,
923 message: "OpTypeArray must have element type and count parameters");
924 }
925
926 Type elementTy = getType(id: operands[1]);
927 if (!elementTy) {
928 return emitError(loc: unknownLoc, message: "OpTypeArray references undefined <id> ")
929 << operands[1];
930 }
931
932 unsigned count = 0;
933 // TODO: The count can also come frome a specialization constant.
934 auto countInfo = getConstant(id: operands[2]);
935 if (!countInfo) {
936 return emitError(loc: unknownLoc, message: "OpTypeArray count <id> ")
937 << operands[2] << "can only come from normal constant right now";
938 }
939
940 if (auto intVal = dyn_cast<IntegerAttr>(countInfo->first)) {
941 count = intVal.getValue().getZExtValue();
942 } else {
943 return emitError(loc: unknownLoc, message: "OpTypeArray count must come from a "
944 "scalar integer constant instruction");
945 }
946
947 typeMap[operands[0]] = spirv::ArrayType::get(
948 elementType: elementTy, elementCount: count, stride: typeDecorations.lookup(Val: operands[0]));
949 return success();
950}
951
952LogicalResult
953spirv::Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
954 assert(!operands.empty() && "No operands for processing function type");
955 if (operands.size() == 1) {
956 return emitError(loc: unknownLoc, message: "missing return type for OpTypeFunction");
957 }
958 auto returnType = getType(id: operands[1]);
959 if (!returnType) {
960 return emitError(loc: unknownLoc, message: "unknown return type in OpTypeFunction");
961 }
962 SmallVector<Type, 1> argTypes;
963 for (size_t i = 2, e = operands.size(); i < e; ++i) {
964 auto ty = getType(id: operands[i]);
965 if (!ty) {
966 return emitError(loc: unknownLoc, message: "unknown argument type in OpTypeFunction");
967 }
968 argTypes.push_back(Elt: ty);
969 }
970 ArrayRef<Type> returnTypes;
971 if (!isVoidType(type: returnType)) {
972 returnTypes = llvm::ArrayRef(returnType);
973 }
974 typeMap[operands[0]] = FunctionType::get(context, argTypes, returnTypes);
975 return success();
976}
977
978LogicalResult spirv::Deserializer::processCooperativeMatrixTypeKHR(
979 ArrayRef<uint32_t> operands) {
980 if (operands.size() != 6) {
981 return emitError(loc: unknownLoc,
982 message: "OpTypeCooperativeMatrixKHR must have element type, "
983 "scope, row and column parameters, and use");
984 }
985
986 Type elementTy = getType(id: operands[1]);
987 if (!elementTy) {
988 return emitError(loc: unknownLoc,
989 message: "OpTypeCooperativeMatrixKHR references undefined <id> ")
990 << operands[1];
991 }
992
993 std::optional<spirv::Scope> scope =
994 spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
995 if (!scope) {
996 return emitError(
997 loc: unknownLoc,
998 message: "OpTypeCooperativeMatrixKHR references undefined scope <id> ")
999 << operands[2];
1000 }
1001
1002 unsigned rows = getConstantInt(operands[3]).getInt();
1003 unsigned columns = getConstantInt(operands[4]).getInt();
1004
1005 std::optional<spirv::CooperativeMatrixUseKHR> use =
1006 spirv::symbolizeCooperativeMatrixUseKHR(
1007 getConstantInt(operands[5]).getInt());
1008 if (!use) {
1009 return emitError(
1010 loc: unknownLoc,
1011 message: "OpTypeCooperativeMatrixKHR references undefined use <id> ")
1012 << operands[5];
1013 }
1014
1015 typeMap[operands[0]] =
1016 spirv::CooperativeMatrixType::get(elementTy, rows, columns, *scope, *use);
1017 return success();
1018}
1019
1020LogicalResult
1021spirv::Deserializer::processJointMatrixType(ArrayRef<uint32_t> operands) {
1022 if (operands.size() != 6) {
1023 return emitError(loc: unknownLoc, message: "OpTypeJointMatrix must have element "
1024 "type and row x column parameters");
1025 }
1026
1027 Type elementTy = getType(id: operands[1]);
1028 if (!elementTy) {
1029 return emitError(loc: unknownLoc, message: "OpTypeJointMatrix references undefined <id> ")
1030 << operands[1];
1031 }
1032
1033 auto scope = spirv::symbolizeScope(getConstantInt(operands[5]).getInt());
1034 if (!scope) {
1035 return emitError(loc: unknownLoc,
1036 message: "OpTypeJointMatrix references undefined scope <id> ")
1037 << operands[5];
1038 }
1039 auto matrixLayout =
1040 spirv::symbolizeMatrixLayout(getConstantInt(operands[4]).getInt());
1041 if (!matrixLayout) {
1042 return emitError(loc: unknownLoc,
1043 message: "OpTypeJointMatrix references undefined scope <id> ")
1044 << operands[4];
1045 }
1046 unsigned rows = getConstantInt(operands[2]).getInt();
1047 unsigned columns = getConstantInt(operands[3]).getInt();
1048
1049 typeMap[operands[0]] = spirv::JointMatrixINTELType::get(
1050 elementTy, scope.value(), rows, columns, matrixLayout.value());
1051 return success();
1052}
1053
1054LogicalResult
1055spirv::Deserializer::processRuntimeArrayType(ArrayRef<uint32_t> operands) {
1056 if (operands.size() != 2) {
1057 return emitError(loc: unknownLoc, message: "OpTypeRuntimeArray must have two operands");
1058 }
1059 Type memberType = getType(id: operands[1]);
1060 if (!memberType) {
1061 return emitError(loc: unknownLoc,
1062 message: "OpTypeRuntimeArray references undefined <id> ")
1063 << operands[1];
1064 }
1065 typeMap[operands[0]] = spirv::RuntimeArrayType::get(
1066 elementType: memberType, stride: typeDecorations.lookup(Val: operands[0]));
1067 return success();
1068}
1069
1070LogicalResult
1071spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) {
1072 // TODO: Find a way to handle identified structs when debug info is stripped.
1073
1074 if (operands.empty()) {
1075 return emitError(loc: unknownLoc, message: "OpTypeStruct must have at least result <id>");
1076 }
1077
1078 if (operands.size() == 1) {
1079 // Handle empty struct.
1080 typeMap[operands[0]] =
1081 spirv::StructType::getEmpty(context, identifier: nameMap.lookup(Val: operands[0]).str());
1082 return success();
1083 }
1084
1085 // First element is operand ID, second element is member index in the struct.
1086 SmallVector<std::pair<uint32_t, unsigned>, 0> unresolvedMemberTypes;
1087 SmallVector<Type, 4> memberTypes;
1088
1089 for (auto op : llvm::drop_begin(RangeOrContainer&: operands, N: 1)) {
1090 Type memberType = getType(id: op);
1091 bool typeForwardPtr = (typeForwardPointerIDs.count(key: op) != 0);
1092
1093 if (!memberType && !typeForwardPtr)
1094 return emitError(loc: unknownLoc, message: "OpTypeStruct references undefined <id> ")
1095 << op;
1096
1097 if (!memberType)
1098 unresolvedMemberTypes.emplace_back(Args&: op, Args: memberTypes.size());
1099
1100 memberTypes.push_back(Elt: memberType);
1101 }
1102
1103 SmallVector<spirv::StructType::OffsetInfo, 0> offsetInfo;
1104 SmallVector<spirv::StructType::MemberDecorationInfo, 0> memberDecorationsInfo;
1105 if (memberDecorationMap.count(operands[0])) {
1106 auto &allMemberDecorations = memberDecorationMap[operands[0]];
1107 for (auto memberIndex : llvm::seq<uint32_t>(Begin: 0, End: memberTypes.size())) {
1108 if (allMemberDecorations.count(memberIndex)) {
1109 for (auto &memberDecoration : allMemberDecorations[memberIndex]) {
1110 // Check for offset.
1111 if (memberDecoration.first == spirv::Decoration::Offset) {
1112 // If offset info is empty, resize to the number of members;
1113 if (offsetInfo.empty()) {
1114 offsetInfo.resize(memberTypes.size());
1115 }
1116 offsetInfo[memberIndex] = memberDecoration.second[0];
1117 } else {
1118 if (!memberDecoration.second.empty()) {
1119 memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/1,
1120 memberDecoration.first,
1121 memberDecoration.second[0]);
1122 } else {
1123 memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/0,
1124 memberDecoration.first, 0);
1125 }
1126 }
1127 }
1128 }
1129 }
1130 }
1131
1132 uint32_t structID = operands[0];
1133 std::string structIdentifier = nameMap.lookup(Val: structID).str();
1134
1135 if (structIdentifier.empty()) {
1136 assert(unresolvedMemberTypes.empty() &&
1137 "didn't expect unresolved member types");
1138 typeMap[structID] =
1139 spirv::StructType::get(memberTypes, offsetInfo, memberDecorations: memberDecorationsInfo);
1140 } else {
1141 auto structTy = spirv::StructType::getIdentified(context, identifier: structIdentifier);
1142 typeMap[structID] = structTy;
1143
1144 if (!unresolvedMemberTypes.empty())
1145 deferredStructTypesInfos.push_back(Elt: {.deferredStructType: structTy, .unresolvedMemberTypes: unresolvedMemberTypes,
1146 .memberTypes: memberTypes, .offsetInfo: offsetInfo,
1147 .memberDecorationsInfo: memberDecorationsInfo});
1148 else if (failed(result: structTy.trySetBody(memberTypes, offsetInfo,
1149 memberDecorations: memberDecorationsInfo)))
1150 return failure();
1151 }
1152
1153 // TODO: Update StructType to have member name as attribute as
1154 // well.
1155 return success();
1156}
1157
1158LogicalResult
1159spirv::Deserializer::processMatrixType(ArrayRef<uint32_t> operands) {
1160 if (operands.size() != 3) {
1161 // Three operands are needed: result_id, column_type, and column_count
1162 return emitError(loc: unknownLoc, message: "OpTypeMatrix must have 3 operands"
1163 " (result_id, column_type, and column_count)");
1164 }
1165 // Matrix columns must be of vector type
1166 Type elementTy = getType(id: operands[1]);
1167 if (!elementTy) {
1168 return emitError(loc: unknownLoc,
1169 message: "OpTypeMatrix references undefined column type.")
1170 << operands[1];
1171 }
1172
1173 uint32_t colsCount = operands[2];
1174 typeMap[operands[0]] = spirv::MatrixType::get(columnType: elementTy, columnCount: colsCount);
1175 return success();
1176}
1177
1178LogicalResult
1179spirv::Deserializer::processTypeForwardPointer(ArrayRef<uint32_t> operands) {
1180 if (operands.size() != 2)
1181 return emitError(loc: unknownLoc,
1182 message: "OpTypeForwardPointer instruction must have two operands");
1183
1184 typeForwardPointerIDs.insert(X: operands[0]);
1185 // TODO: Use the 2nd operand (Storage Class) to validate the OpTypePointer
1186 // instruction that defines the actual type.
1187
1188 return success();
1189}
1190
1191LogicalResult
1192spirv::Deserializer::processImageType(ArrayRef<uint32_t> operands) {
1193 // TODO: Add support for Access Qualifier.
1194 if (operands.size() != 8)
1195 return emitError(
1196 loc: unknownLoc,
1197 message: "OpTypeImage with non-eight operands are not supported yet");
1198
1199 Type elementTy = getType(id: operands[1]);
1200 if (!elementTy)
1201 return emitError(loc: unknownLoc, message: "OpTypeImage references undefined <id>: ")
1202 << operands[1];
1203
1204 auto dim = spirv::symbolizeDim(operands[2]);
1205 if (!dim)
1206 return emitError(loc: unknownLoc, message: "unknown Dim for OpTypeImage: ")
1207 << operands[2];
1208
1209 auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]);
1210 if (!depthInfo)
1211 return emitError(loc: unknownLoc, message: "unknown Depth for OpTypeImage: ")
1212 << operands[3];
1213
1214 auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]);
1215 if (!arrayedInfo)
1216 return emitError(loc: unknownLoc, message: "unknown Arrayed for OpTypeImage: ")
1217 << operands[4];
1218
1219 auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]);
1220 if (!samplingInfo)
1221 return emitError(loc: unknownLoc, message: "unknown MS for OpTypeImage: ") << operands[5];
1222
1223 auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]);
1224 if (!samplerUseInfo)
1225 return emitError(loc: unknownLoc, message: "unknown Sampled for OpTypeImage: ")
1226 << operands[6];
1227
1228 auto format = spirv::symbolizeImageFormat(operands[7]);
1229 if (!format)
1230 return emitError(loc: unknownLoc, message: "unknown Format for OpTypeImage: ")
1231 << operands[7];
1232
1233 typeMap[operands[0]] = spirv::ImageType::get(
1234 elementTy, dim.value(), depthInfo.value(), arrayedInfo.value(),
1235 samplingInfo.value(), samplerUseInfo.value(), format.value());
1236 return success();
1237}
1238
1239LogicalResult
1240spirv::Deserializer::processSampledImageType(ArrayRef<uint32_t> operands) {
1241 if (operands.size() != 2)
1242 return emitError(loc: unknownLoc, message: "OpTypeSampledImage must have two operands");
1243
1244 Type elementTy = getType(id: operands[1]);
1245 if (!elementTy)
1246 return emitError(loc: unknownLoc,
1247 message: "OpTypeSampledImage references undefined <id>: ")
1248 << operands[1];
1249
1250 typeMap[operands[0]] = spirv::SampledImageType::get(imageType: elementTy);
1251 return success();
1252}
1253
1254//===----------------------------------------------------------------------===//
1255// Constant
1256//===----------------------------------------------------------------------===//
1257
1258LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands,
1259 bool isSpec) {
1260 StringRef opname = isSpec ? "OpSpecConstant" : "OpConstant";
1261
1262 if (operands.size() < 2) {
1263 return emitError(loc: unknownLoc)
1264 << opname << " must have type <id> and result <id>";
1265 }
1266 if (operands.size() < 3) {
1267 return emitError(loc: unknownLoc)
1268 << opname << " must have at least 1 more parameter";
1269 }
1270
1271 Type resultType = getType(id: operands[0]);
1272 if (!resultType) {
1273 return emitError(loc: unknownLoc, message: "undefined result type from <id> ")
1274 << operands[0];
1275 }
1276
1277 auto checkOperandSizeForBitwidth = [&](unsigned bitwidth) -> LogicalResult {
1278 if (bitwidth == 64) {
1279 if (operands.size() == 4) {
1280 return success();
1281 }
1282 return emitError(loc: unknownLoc)
1283 << opname << " should have 2 parameters for 64-bit values";
1284 }
1285 if (bitwidth <= 32) {
1286 if (operands.size() == 3) {
1287 return success();
1288 }
1289
1290 return emitError(loc: unknownLoc)
1291 << opname
1292 << " should have 1 parameter for values with no more than 32 bits";
1293 }
1294 return emitError(loc: unknownLoc, message: "unsupported OpConstant bitwidth: ")
1295 << bitwidth;
1296 };
1297
1298 auto resultID = operands[1];
1299
1300 if (auto intType = dyn_cast<IntegerType>(resultType)) {
1301 auto bitwidth = intType.getWidth();
1302 if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1303 return failure();
1304 }
1305
1306 APInt value;
1307 if (bitwidth == 64) {
1308 // 64-bit integers are represented with two SPIR-V words. According to
1309 // SPIR-V spec: "When the type’s bit width is larger than one word, the
1310 // literal’s low-order words appear first."
1311 struct DoubleWord {
1312 uint32_t word1;
1313 uint32_t word2;
1314 } words = {.word1: operands[2], .word2: operands[3]};
1315 value = APInt(64, llvm::bit_cast<uint64_t>(from: words), /*isSigned=*/true);
1316 } else if (bitwidth <= 32) {
1317 value = APInt(bitwidth, operands[2], /*isSigned=*/true);
1318 }
1319
1320 auto attr = opBuilder.getIntegerAttr(intType, value);
1321
1322 if (isSpec) {
1323 createSpecConstant(unknownLoc, resultID, attr);
1324 } else {
1325 // For normal constants, we just record the attribute (and its type) for
1326 // later materialization at use sites.
1327 constantMap.try_emplace(resultID, attr, intType);
1328 }
1329
1330 return success();
1331 }
1332
1333 if (auto floatType = dyn_cast<FloatType>(Val&: resultType)) {
1334 auto bitwidth = floatType.getWidth();
1335 if (failed(result: checkOperandSizeForBitwidth(bitwidth))) {
1336 return failure();
1337 }
1338
1339 APFloat value(0.f);
1340 if (floatType.isF64()) {
1341 // Double values are represented with two SPIR-V words. According to
1342 // SPIR-V spec: "When the type’s bit width is larger than one word, the
1343 // literal’s low-order words appear first."
1344 struct DoubleWord {
1345 uint32_t word1;
1346 uint32_t word2;
1347 } words = {.word1: operands[2], .word2: operands[3]};
1348 value = APFloat(llvm::bit_cast<double>(from: words));
1349 } else if (floatType.isF32()) {
1350 value = APFloat(llvm::bit_cast<float>(from: operands[2]));
1351 } else if (floatType.isF16()) {
1352 APInt data(16, operands[2]);
1353 value = APFloat(APFloat::IEEEhalf(), data);
1354 }
1355
1356 auto attr = opBuilder.getFloatAttr(floatType, value);
1357 if (isSpec) {
1358 createSpecConstant(unknownLoc, resultID, attr);
1359 } else {
1360 // For normal constants, we just record the attribute (and its type) for
1361 // later materialization at use sites.
1362 constantMap.try_emplace(resultID, attr, floatType);
1363 }
1364
1365 return success();
1366 }
1367
1368 return emitError(loc: unknownLoc, message: "OpConstant can only generate values of "
1369 "scalar integer or floating-point type");
1370}
1371
1372LogicalResult spirv::Deserializer::processConstantBool(
1373 bool isTrue, ArrayRef<uint32_t> operands, bool isSpec) {
1374 if (operands.size() != 2) {
1375 return emitError(loc: unknownLoc, message: "Op")
1376 << (isSpec ? "Spec" : "") << "Constant"
1377 << (isTrue ? "True" : "False")
1378 << " must have type <id> and result <id>";
1379 }
1380
1381 auto attr = opBuilder.getBoolAttr(value: isTrue);
1382 auto resultID = operands[1];
1383 if (isSpec) {
1384 createSpecConstant(unknownLoc, resultID, attr);
1385 } else {
1386 // For normal constants, we just record the attribute (and its type) for
1387 // later materialization at use sites.
1388 constantMap.try_emplace(Key: resultID, Args&: attr, Args: opBuilder.getI1Type());
1389 }
1390
1391 return success();
1392}
1393
1394LogicalResult
1395spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
1396 if (operands.size() < 2) {
1397 return emitError(loc: unknownLoc,
1398 message: "OpConstantComposite must have type <id> and result <id>");
1399 }
1400 if (operands.size() < 3) {
1401 return emitError(loc: unknownLoc,
1402 message: "OpConstantComposite must have at least 1 parameter");
1403 }
1404
1405 Type resultType = getType(id: operands[0]);
1406 if (!resultType) {
1407 return emitError(loc: unknownLoc, message: "undefined result type from <id> ")
1408 << operands[0];
1409 }
1410
1411 SmallVector<Attribute, 4> elements;
1412 elements.reserve(N: operands.size() - 2);
1413 for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1414 auto elementInfo = getConstant(id: operands[i]);
1415 if (!elementInfo) {
1416 return emitError(loc: unknownLoc, message: "OpConstantComposite component <id> ")
1417 << operands[i] << " must come from a normal constant";
1418 }
1419 elements.push_back(Elt: elementInfo->first);
1420 }
1421
1422 auto resultID = operands[1];
1423 if (auto vectorType = dyn_cast<VectorType>(resultType)) {
1424 auto attr = DenseElementsAttr::get(vectorType, elements);
1425 // For normal constants, we just record the attribute (and its type) for
1426 // later materialization at use sites.
1427 constantMap.try_emplace(resultID, attr, resultType);
1428 } else if (auto arrayType = dyn_cast<spirv::ArrayType>(Val&: resultType)) {
1429 auto attr = opBuilder.getArrayAttr(elements);
1430 constantMap.try_emplace(resultID, attr, resultType);
1431 } else {
1432 return emitError(loc: unknownLoc, message: "unsupported OpConstantComposite type: ")
1433 << resultType;
1434 }
1435
1436 return success();
1437}
1438
1439LogicalResult
1440spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
1441 if (operands.size() < 2) {
1442 return emitError(loc: unknownLoc,
1443 message: "OpConstantComposite must have type <id> and result <id>");
1444 }
1445 if (operands.size() < 3) {
1446 return emitError(loc: unknownLoc,
1447 message: "OpConstantComposite must have at least 1 parameter");
1448 }
1449
1450 Type resultType = getType(id: operands[0]);
1451 if (!resultType) {
1452 return emitError(loc: unknownLoc, message: "undefined result type from <id> ")
1453 << operands[0];
1454 }
1455
1456 auto resultID = operands[1];
1457 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(id: resultID));
1458
1459 SmallVector<Attribute, 4> elements;
1460 elements.reserve(N: operands.size() - 2);
1461 for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1462 auto elementInfo = getSpecConstant(operands[i]);
1463 elements.push_back(SymbolRefAttr::get(elementInfo));
1464 }
1465
1466 auto op = opBuilder.create<spirv::SpecConstantCompositeOp>(
1467 unknownLoc, TypeAttr::get(resultType), symName,
1468 opBuilder.getArrayAttr(elements));
1469 specConstCompositeMap[resultID] = op;
1470
1471 return success();
1472}
1473
1474LogicalResult
1475spirv::Deserializer::processSpecConstantOperation(ArrayRef<uint32_t> operands) {
1476 if (operands.size() < 3)
1477 return emitError(loc: unknownLoc, message: "OpConstantOperation must have type <id>, "
1478 "result <id>, and operand opcode");
1479
1480 uint32_t resultTypeID = operands[0];
1481
1482 if (!getType(id: resultTypeID))
1483 return emitError(loc: unknownLoc, message: "undefined result type from <id> ")
1484 << resultTypeID;
1485
1486 uint32_t resultID = operands[1];
1487 spirv::Opcode enclosedOpcode = static_cast<spirv::Opcode>(operands[2]);
1488 auto emplaceResult = specConstOperationMap.try_emplace(
1489 Key: resultID,
1490 Args: SpecConstOperationMaterializationInfo{
1491 enclosedOpcode, resultTypeID,
1492 SmallVector<uint32_t>{operands.begin() + 3, operands.end()}});
1493
1494 if (!emplaceResult.second)
1495 return emitError(loc: unknownLoc, message: "value with <id>: ")
1496 << resultID << " is probably defined before.";
1497
1498 return success();
1499}
1500
/// Builds a spirv::SpecConstantOperationOp, at the builder's current insertion
/// point, for a previously recorded OpSpecConstantOp.
///
/// The enclosed instruction (`enclosedOpcode` applied to `enclosedOpOperands`,
/// with result type `resultTypeID`) is deserialized via processInstruction
/// while `valueMap` is temporarily replaced by an empty map. The op it
/// produced is then split off into its own block. That block is spliced into
/// the new op's body and terminated with a spirv::YieldOp of the op's first
/// result.
///
/// Returns the SpecConstantOperationOp's result, or a null Value if
/// processInstruction fails.
/// NOTE(review): `resultID` is not referenced in this body.
Value spirv::Deserializer::materializeSpecConstantOperation(
    uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID,
    ArrayRef<uint32_t> enclosedOpOperands) {

  Type resultType = getType(id: resultTypeID);

  // Instructions wrapped by OpSpecConstantOp need an ID for their
  // Deserializer::processOp<op_name>(...) to emit the corresponding SPIR-V
  // dialect wrapped op. For that purpose, a new value map is created and "fake"
  // ID in that map is assigned to the result of the enclosed instruction. Note
  // that there is no need to update this fake ID since we only need to
  // reference the created Value for the enclosed op from the spv::YieldOp
  // created later in this method (both of which are the only values in their
  // region: the SpecConstantOperation's region). If we encounter another
  // SpecConstantOperation in the module, we simply re-use the fake ID since the
  // previous Value assigned to it isn't visible in the current scope anyway.
  DenseMap<uint32_t, Value> newValueMap;
  // Swaps in the empty map for the rest of this function; the original
  // valueMap is restored when the guard goes out of scope.
  llvm::SaveAndRestore valueMapGuard(valueMap, newValueMap);
  constexpr uint32_t fakeID = static_cast<uint32_t>(-3);

  // Rebuild the operand list processInstruction expects:
  // [result type <id>, result <id> (fakeID), enclosed operands...].
  SmallVector<uint32_t, 4> enclosedOpResultTypeAndOperands;
  enclosedOpResultTypeAndOperands.push_back(Elt: resultTypeID);
  enclosedOpResultTypeAndOperands.push_back(Elt: fakeID);
  enclosedOpResultTypeAndOperands.append(in_start: enclosedOpOperands.begin(),
                                         in_end: enclosedOpOperands.end());

  // Process enclosed instruction before creating the enclosing
  // specConstantOperation (and its region). This way, references to constants,
  // global variables, and spec constants will be materialized outside the new
  // op's region. For more info, see Deserializer::getValue's implementation.
  if (failed(
          processInstruction(enclosedOpcode, enclosedOpResultTypeAndOperands)))
    return Value();

  // Since the enclosed op is emitted in the current block, split it in a
  // separate new block.
  // NOTE(review): this assumes the enclosed op is the last op in curBlock
  // after processInstruction returns.
  Block *enclosedBlock = curBlock->splitBlock(splitBeforeOp: &curBlock->back());

  auto loc = createFileLineColLoc(opBuilder);
  auto specConstOperationOp =
      opBuilder.create<spirv::SpecConstantOperationOp>(loc, resultType);

  Region &body = specConstOperationOp.getBody();
  // Move the new block into SpecConstantOperation's body.
  body.getBlocks().splice(where: body.end(), L2&: curBlock->getParent()->getBlocks(),
                          first: Region::iterator(enclosedBlock));
  Block &block = body.back();

  // RAII guard to reset the insertion point to the module's region after
  // deserializing the body of the specConstantOperation.
  OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
  opBuilder.setInsertionPointToEnd(&block);

  // block.front() is the enclosed op that was split off above; yield its
  // first result as the value of the SpecConstantOperation.
  opBuilder.create<spirv::YieldOp>(loc, block.front().getResult(0));
  return specConstOperationOp.getResult();
}
1557
1558LogicalResult
1559spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
1560 if (operands.size() != 2) {
1561 return emitError(loc: unknownLoc,
1562 message: "OpConstantNull must have type <id> and result <id>");
1563 }
1564
1565 Type resultType = getType(id: operands[0]);
1566 if (!resultType) {
1567 return emitError(loc: unknownLoc, message: "undefined result type from <id> ")
1568 << operands[0];
1569 }
1570
1571 auto resultID = operands[1];
1572 if (resultType.isIntOrFloat() || isa<VectorType>(Val: resultType)) {
1573 auto attr = opBuilder.getZeroAttr(resultType);
1574 // For normal constants, we just record the attribute (and its type) for
1575 // later materialization at use sites.
1576 constantMap.try_emplace(resultID, attr, resultType);
1577 return success();
1578 }
1579
1580 return emitError(loc: unknownLoc, message: "unsupported OpConstantNull type: ")
1581 << resultType;
1582}
1583
1584//===----------------------------------------------------------------------===//
1585// Control flow
1586//===----------------------------------------------------------------------===//
1587
1588Block *spirv::Deserializer::getOrCreateBlock(uint32_t id) {
1589 if (auto *block = getBlock(id)) {
1590 LLVM_DEBUG(logger.startLine() << "[block] got exiting block for id = " << id
1591 << " @ " << block << "\n");
1592 return block;
1593 }
1594
1595 // We don't know where this block will be placed finally (in a
1596 // spirv.mlir.selection or spirv.mlir.loop or function). Create it into the
1597 // function for now and sort out the proper place later.
1598 auto *block = curFunction->addBlock();
1599 LLVM_DEBUG(logger.startLine() << "[block] created block for id = " << id
1600 << " @ " << block << "\n");
1601 return blockMap[id] = block;
1602}
1603
1604LogicalResult spirv::Deserializer::processBranch(ArrayRef<uint32_t> operands) {
1605 if (!curBlock) {
1606 return emitError(loc: unknownLoc, message: "OpBranch must appear inside a block");
1607 }
1608
1609 if (operands.size() != 1) {
1610 return emitError(loc: unknownLoc, message: "OpBranch must take exactly one target label");
1611 }
1612
1613 auto *target = getOrCreateBlock(id: operands[0]);
1614 auto loc = createFileLineColLoc(opBuilder);
1615 // The preceding instruction for the OpBranch instruction could be an
1616 // OpLoopMerge or an OpSelectionMerge instruction, in this case they will have
1617 // the same OpLine information.
1618 opBuilder.create<spirv::BranchOp>(loc, target);
1619
1620 clearDebugLine();
1621 return success();
1622}
1623
1624LogicalResult
1625spirv::Deserializer::processBranchConditional(ArrayRef<uint32_t> operands) {
1626 if (!curBlock) {
1627 return emitError(loc: unknownLoc,
1628 message: "OpBranchConditional must appear inside a block");
1629 }
1630
1631 if (operands.size() != 3 && operands.size() != 5) {
1632 return emitError(loc: unknownLoc,
1633 message: "OpBranchConditional must have condition, true label, "
1634 "false label, and optionally two branch weights");
1635 }
1636
1637 auto condition = getValue(id: operands[0]);
1638 auto *trueBlock = getOrCreateBlock(id: operands[1]);
1639 auto *falseBlock = getOrCreateBlock(id: operands[2]);
1640
1641 std::optional<std::pair<uint32_t, uint32_t>> weights;
1642 if (operands.size() == 5) {
1643 weights = std::make_pair(x: operands[3], y: operands[4]);
1644 }
1645 // The preceding instruction for the OpBranchConditional instruction could be
1646 // an OpSelectionMerge instruction, in this case they will have the same
1647 // OpLine information.
1648 auto loc = createFileLineColLoc(opBuilder);
1649 opBuilder.create<spirv::BranchConditionalOp>(
1650 loc, condition, trueBlock,
1651 /*trueArguments=*/ArrayRef<Value>(), falseBlock,
1652 /*falseArguments=*/ArrayRef<Value>(), weights);
1653
1654 clearDebugLine();
1655 return success();
1656}
1657
1658LogicalResult spirv::Deserializer::processLabel(ArrayRef<uint32_t> operands) {
1659 if (!curFunction) {
1660 return emitError(loc: unknownLoc, message: "OpLabel must appear inside a function");
1661 }
1662
1663 if (operands.size() != 1) {
1664 return emitError(loc: unknownLoc, message: "OpLabel should only have result <id>");
1665 }
1666
1667 auto labelID = operands[0];
1668 // We may have forward declared this block.
1669 auto *block = getOrCreateBlock(id: labelID);
1670 LLVM_DEBUG(logger.startLine()
1671 << "[block] populating block " << block << "\n");
1672 // If we have seen this block, make sure it was just a forward declaration.
1673 assert(block->empty() && "re-deserialize the same block!");
1674
1675 opBuilder.setInsertionPointToStart(block);
1676 blockMap[labelID] = curBlock = block;
1677
1678 return success();
1679}
1680
1681LogicalResult
1682spirv::Deserializer::processSelectionMerge(ArrayRef<uint32_t> operands) {
1683 if (!curBlock) {
1684 return emitError(loc: unknownLoc, message: "OpSelectionMerge must appear in a block");
1685 }
1686
1687 if (operands.size() < 2) {
1688 return emitError(
1689 loc: unknownLoc,
1690 message: "OpSelectionMerge must specify merge target and selection control");
1691 }
1692
1693 auto *mergeBlock = getOrCreateBlock(id: operands[0]);
1694 auto loc = createFileLineColLoc(opBuilder);
1695 auto selectionControl = operands[1];
1696
1697 if (!blockMergeInfo.try_emplace(Key: curBlock, Args&: loc, Args&: selectionControl, Args&: mergeBlock)
1698 .second) {
1699 return emitError(
1700 loc: unknownLoc,
1701 message: "a block cannot have more than one OpSelectionMerge instruction");
1702 }
1703
1704 return success();
1705}
1706
1707LogicalResult
1708spirv::Deserializer::processLoopMerge(ArrayRef<uint32_t> operands) {
1709 if (!curBlock) {
1710 return emitError(loc: unknownLoc, message: "OpLoopMerge must appear in a block");
1711 }
1712
1713 if (operands.size() < 3) {
1714 return emitError(loc: unknownLoc, message: "OpLoopMerge must specify merge target, "
1715 "continue target and loop control");
1716 }
1717
1718 auto *mergeBlock = getOrCreateBlock(id: operands[0]);
1719 auto *continueBlock = getOrCreateBlock(id: operands[1]);
1720 auto loc = createFileLineColLoc(opBuilder);
1721 uint32_t loopControl = operands[2];
1722
1723 if (!blockMergeInfo
1724 .try_emplace(Key: curBlock, Args&: loc, Args&: loopControl, Args&: mergeBlock, Args&: continueBlock)
1725 .second) {
1726 return emitError(
1727 loc: unknownLoc,
1728 message: "a block cannot have more than one OpLoopMerge instruction");
1729 }
1730
1731 return success();
1732}
1733
1734LogicalResult spirv::Deserializer::processPhi(ArrayRef<uint32_t> operands) {
1735 if (!curBlock) {
1736 return emitError(loc: unknownLoc, message: "OpPhi must appear in a block");
1737 }
1738
1739 if (operands.size() < 4) {
1740 return emitError(loc: unknownLoc, message: "OpPhi must specify result type, result <id>, "
1741 "and variable-parent pairs");
1742 }
1743
1744 // Create a block argument for this OpPhi instruction.
1745 Type blockArgType = getType(id: operands[0]);
1746 BlockArgument blockArg = curBlock->addArgument(type: blockArgType, loc: unknownLoc);
1747 valueMap[operands[1]] = blockArg;
1748 LLVM_DEBUG(logger.startLine()
1749 << "[phi] created block argument " << blockArg
1750 << " id = " << operands[1] << " of type " << blockArgType << "\n");
1751
1752 // For each (value, predecessor) pair, insert the value to the predecessor's
1753 // blockPhiInfo entry so later we can fix the block argument there.
1754 for (unsigned i = 2, e = operands.size(); i < e; i += 2) {
1755 uint32_t value = operands[i];
1756 Block *predecessor = getOrCreateBlock(id: operands[i + 1]);
1757 std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock};
1758 blockPhiInfo[predecessorTargetPair].push_back(Elt: value);
1759 LLVM_DEBUG(logger.startLine() << "[phi] predecessor @ " << predecessor
1760 << " with arg id = " << value << "\n");
1761 }
1762
1763 return success();
1764}
1765
namespace {
/// A class for putting all blocks in a structured selection/loop in a
/// spirv.mlir.selection/spirv.mlir.loop op.
///
/// One instance handles exactly one construct, identified by its header,
/// merge and (for loops) continue block.
class ControlFlowStructurizer {
public:
  // Two constructor variants exist because the `logger` member is only
  // declared in builds without NDEBUG; release builds take no logger.
#ifndef NDEBUG
  ControlFlowStructurizer(Location loc, uint32_t control,
                          spirv::BlockMergeInfoMap &mergeInfo, Block *header,
                          Block *merge, Block *cont,
                          llvm::ScopedPrinter &logger)
      : location(loc), control(control), blockMergeInfo(mergeInfo),
        headerBlock(header), mergeBlock(merge), continueBlock(cont),
        logger(logger) {}
#else
  ControlFlowStructurizer(Location loc, uint32_t control,
                          spirv::BlockMergeInfoMap &mergeInfo, Block *header,
                          Block *merge, Block *cont)
      : location(loc), control(control), blockMergeInfo(mergeInfo),
        headerBlock(header), mergeBlock(merge), continueBlock(cont) {}
#endif

  /// Structurizes the selection or loop construct at the given `headerBlock`.
  ///
  /// This method will create an spirv.mlir.loop op (when `continueBlock` is
  /// non-null) or an spirv.mlir.selection op (otherwise) at the beginning of
  /// the `mergeBlock`, and copy all blocks in the construct into that op's
  /// region. All branches to the `headerBlock` will be redirected to the
  /// `mergeBlock`. This method will also update `blockMergeInfo` by remapping
  /// all nested construct headers inside to the newly cloned ones inside the
  /// structured control flow op's region.
  LogicalResult structurize();

private:
  /// Creates a new spirv.mlir.selection op at the beginning of the
  /// `mergeBlock`.
  spirv::SelectionOp createSelectionOp(uint32_t selectionControl);

  /// Creates a new spirv.mlir.loop op at the beginning of the `mergeBlock`.
  spirv::LoopOp createLoopOp(uint32_t loopControl);

  /// Collects all blocks reachable from `headerBlock` except `mergeBlock`.
  void collectBlocksInConstruct();

  /// Location given to the created selection/loop op.
  Location location;
  /// Raw selection- or loop-control mask from the merge instruction.
  uint32_t control;

  /// The deserializer's pending merge info; nested entries get re-keyed here.
  spirv::BlockMergeInfoMap &blockMergeInfo;

  Block *headerBlock;
  Block *mergeBlock;
  Block *continueBlock; // nullptr for spirv.mlir.selection

  /// Blocks of the construct in discovery order (header first).
  SetVector<Block *> constructBlocks;

#ifndef NDEBUG
  /// A logger used to emit information during the deserialization process.
  llvm::ScopedPrinter &logger;
#endif
};
} // namespace
1824
1825spirv::SelectionOp
1826ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
1827 // Create a builder and set the insertion point to the beginning of the
1828 // merge block so that the newly created SelectionOp will be inserted there.
1829 OpBuilder builder(&mergeBlock->front());
1830
1831 auto control = static_cast<spirv::SelectionControl>(selectionControl);
1832 auto selectionOp = builder.create<spirv::SelectionOp>(location, control);
1833 selectionOp.addMergeBlock(builder);
1834
1835 return selectionOp;
1836}
1837
1838spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {
1839 // Create a builder and set the insertion point to the beginning of the
1840 // merge block so that the newly created LoopOp will be inserted there.
1841 OpBuilder builder(&mergeBlock->front());
1842
1843 auto control = static_cast<spirv::LoopControl>(loopControl);
1844 auto loopOp = builder.create<spirv::LoopOp>(location, control);
1845 loopOp.addEntryAndMergeBlock(builder);
1846
1847 return loopOp;
1848}
1849
1850void ControlFlowStructurizer::collectBlocksInConstruct() {
1851 assert(constructBlocks.empty() && "expected empty constructBlocks");
1852
1853 // Put the header block in the work list first.
1854 constructBlocks.insert(X: headerBlock);
1855
1856 // For each item in the work list, add its successors excluding the merge
1857 // block.
1858 for (unsigned i = 0; i < constructBlocks.size(); ++i) {
1859 for (auto *successor : constructBlocks[i]->getSuccessors())
1860 if (successor != mergeBlock)
1861 constructBlocks.insert(X: successor);
1862 }
1863}
1864
1865LogicalResult ControlFlowStructurizer::structurize() {
1866 Operation *op = nullptr;
1867 bool isLoop = continueBlock != nullptr;
1868 if (isLoop) {
1869 if (auto loopOp = createLoopOp(control))
1870 op = loopOp.getOperation();
1871 } else {
1872 if (auto selectionOp = createSelectionOp(control))
1873 op = selectionOp.getOperation();
1874 }
1875 if (!op)
1876 return failure();
1877 Region &body = op->getRegion(index: 0);
1878
1879 IRMapping mapper;
1880 // All references to the old merge block should be directed to the
1881 // selection/loop merge block in the SelectionOp/LoopOp's region.
1882 mapper.map(from: mergeBlock, to: &body.back());
1883
1884 collectBlocksInConstruct();
1885
1886 // We've identified all blocks belonging to the selection/loop's region. Now
1887 // need to "move" them into the selection/loop. Instead of really moving the
1888 // blocks, in the following we copy them and remap all values and branches.
1889 // This is because:
1890 // * Inserting a block into a region requires the block not in any region
1891 // before. But selections/loops can nest so we can create selection/loop ops
1892 // in a nested manner, which means some blocks may already be in a
1893 // selection/loop region when to be moved again.
1894 // * It's much trickier to fix up the branches into and out of the loop's
1895 // region: we need to treat not-moved blocks and moved blocks differently:
1896 // Not-moved blocks jumping to the loop header block need to jump to the
1897 // merge point containing the new loop op but not the loop continue block's
1898 // back edge. Moved blocks jumping out of the loop need to jump to the
1899 // merge block inside the loop region but not other not-moved blocks.
1900 // We cannot use replaceAllUsesWith clearly and it's harder to follow the
1901 // logic.
1902
1903 // Create a corresponding block in the SelectionOp/LoopOp's region for each
1904 // block in this loop construct.
1905 OpBuilder builder(body);
1906 for (auto *block : constructBlocks) {
1907 // Create a block and insert it before the selection/loop merge block in the
1908 // SelectionOp/LoopOp's region.
1909 auto *newBlock = builder.createBlock(insertBefore: &body.back());
1910 mapper.map(from: block, to: newBlock);
1911 LLVM_DEBUG(logger.startLine() << "[cf] cloned block " << newBlock
1912 << " from block " << block << "\n");
1913 if (!isFnEntryBlock(block)) {
1914 for (BlockArgument blockArg : block->getArguments()) {
1915 auto newArg =
1916 newBlock->addArgument(type: blockArg.getType(), loc: blockArg.getLoc());
1917 mapper.map(from: blockArg, to: newArg);
1918 LLVM_DEBUG(logger.startLine() << "[cf] remapped block argument "
1919 << blockArg << " to " << newArg << "\n");
1920 }
1921 } else {
1922 LLVM_DEBUG(logger.startLine()
1923 << "[cf] block " << block << " is a function entry block\n");
1924 }
1925
1926 for (auto &op : *block)
1927 newBlock->push_back(op: op.clone(mapper));
1928 }
1929
1930 // Go through all ops and remap the operands.
1931 auto remapOperands = [&](Operation *op) {
1932 for (auto &operand : op->getOpOperands())
1933 if (Value mappedOp = mapper.lookupOrNull(from: operand.get()))
1934 operand.set(mappedOp);
1935 for (auto &succOp : op->getBlockOperands())
1936 if (Block *mappedOp = mapper.lookupOrNull(from: succOp.get()))
1937 succOp.set(mappedOp);
1938 };
1939 for (auto &block : body)
1940 block.walk(callback&: remapOperands);
1941
1942 // We have created the SelectionOp/LoopOp and "moved" all blocks belonging to
1943 // the selection/loop construct into its region. Next we need to fix the
1944 // connections between this new SelectionOp/LoopOp with existing blocks.
1945
1946 // All existing incoming branches should go to the merge block, where the
1947 // SelectionOp/LoopOp resides right now.
1948 headerBlock->replaceAllUsesWith(newValue&: mergeBlock);
1949
1950 LLVM_DEBUG({
1951 logger.startLine() << "[cf] after cloning and fixing references:\n";
1952 headerBlock->getParentOp()->print(logger.getOStream());
1953 logger.startLine() << "\n";
1954 });
1955
1956 if (isLoop) {
1957 if (!mergeBlock->args_empty()) {
1958 return mergeBlock->getParentOp()->emitError(
1959 message: "OpPhi in loop merge block unsupported");
1960 }
1961
1962 // The loop header block may have block arguments. Since now we place the
1963 // loop op inside the old merge block, we need to make sure the old merge
1964 // block has the same block argument list.
1965 for (BlockArgument blockArg : headerBlock->getArguments())
1966 mergeBlock->addArgument(type: blockArg.getType(), loc: blockArg.getLoc());
1967
1968 // If the loop header block has block arguments, make sure the spirv.Branch
1969 // op matches.
1970 SmallVector<Value, 4> blockArgs;
1971 if (!headerBlock->args_empty())
1972 blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()};
1973
1974 // The loop entry block should have a unconditional branch jumping to the
1975 // loop header block.
1976 builder.setInsertionPointToEnd(&body.front());
1977 builder.create<spirv::BranchOp>(location, mapper.lookupOrNull(headerBlock),
1978 ArrayRef<Value>(blockArgs));
1979 }
1980
1981 // All the blocks cloned into the SelectionOp/LoopOp's region can now be
1982 // cleaned up.
1983 LLVM_DEBUG(logger.startLine() << "[cf] cleaning up blocks after clone\n");
1984 // First we need to drop all operands' references inside all blocks. This is
1985 // needed because we can have blocks referencing SSA values from one another.
1986 for (auto *block : constructBlocks)
1987 block->dropAllReferences();
1988
1989 // Check that whether some op in the to-be-erased blocks still has uses. Those
1990 // uses come from blocks that won't be sinked into the SelectionOp/LoopOp's
1991 // region. We cannot handle such cases given that once a value is sinked into
1992 // the SelectionOp/LoopOp's region, there is no escape for it:
1993 // SelectionOp/LooOp does not support yield values right now.
1994 for (auto *block : constructBlocks) {
1995 for (Operation &op : *block)
1996 if (!op.use_empty())
1997 return op.emitOpError(
1998 message: "failed control flow structurization: it has uses outside of the "
1999 "enclosing selection/loop construct");
2000 }
2001
2002 // Then erase all old blocks.
2003 for (auto *block : constructBlocks) {
2004 // We've cloned all blocks belonging to this construct into the structured
2005 // control flow op's region. Among these blocks, some may compose another
2006 // selection/loop. If so, they will be recorded within blockMergeInfo.
2007 // We need to update the pointers there to the newly remapped ones so we can
2008 // continue structurizing them later.
2009 // TODO: The asserts in the following assumes input SPIR-V blob forms
2010 // correctly nested selection/loop constructs. We should relax this and
2011 // support error cases better.
2012 auto it = blockMergeInfo.find(Val: block);
2013 if (it != blockMergeInfo.end()) {
2014 // Use the original location for nested selection/loop ops.
2015 Location loc = it->second.loc;
2016
2017 Block *newHeader = mapper.lookupOrNull(from: block);
2018 if (!newHeader)
2019 return emitError(loc, message: "failed control flow structurization: nested "
2020 "loop header block should be remapped!");
2021
2022 Block *newContinue = it->second.continueBlock;
2023 if (newContinue) {
2024 newContinue = mapper.lookupOrNull(from: newContinue);
2025 if (!newContinue)
2026 return emitError(loc, message: "failed control flow structurization: nested "
2027 "loop continue block should be remapped!");
2028 }
2029
2030 Block *newMerge = it->second.mergeBlock;
2031 if (Block *mappedTo = mapper.lookupOrNull(from: newMerge))
2032 newMerge = mappedTo;
2033
2034 // The iterator should be erased before adding a new entry into
2035 // blockMergeInfo to avoid iterator invalidation.
2036 blockMergeInfo.erase(I: it);
2037 blockMergeInfo.try_emplace(Key: newHeader, Args&: loc, Args&: it->second.control, Args&: newMerge,
2038 Args&: newContinue);
2039 }
2040
2041 // The structured selection/loop's entry block does not have arguments.
2042 // If the function's header block is also part of the structured control
2043 // flow, we cannot just simply erase it because it may contain arguments
2044 // matching the function signature and used by the cloned blocks.
2045 if (isFnEntryBlock(block)) {
2046 LLVM_DEBUG(logger.startLine() << "[cf] changing entry block " << block
2047 << " to only contain a spirv.Branch op\n");
2048 // Still keep the function entry block for the potential block arguments,
2049 // but replace all ops inside with a branch to the merge block.
2050 block->clear();
2051 builder.setInsertionPointToEnd(block);
2052 builder.create<spirv::BranchOp>(location, mergeBlock);
2053 } else {
2054 LLVM_DEBUG(logger.startLine() << "[cf] erasing block " << block << "\n");
2055 block->erase();
2056 }
2057 }
2058
2059 LLVM_DEBUG(logger.startLine()
2060 << "[cf] after structurizing construct with header block "
2061 << headerBlock << ":\n"
2062 << *op << "\n");
2063
2064 return success();
2065}
2066
2067LogicalResult spirv::Deserializer::wireUpBlockArgument() {
2068 LLVM_DEBUG({
2069 logger.startLine()
2070 << "//----- [phi] start wiring up block arguments -----//\n";
2071 logger.indent();
2072 });
2073
2074 OpBuilder::InsertionGuard guard(opBuilder);
2075
2076 for (const auto &info : blockPhiInfo) {
2077 Block *block = info.first.first;
2078 Block *target = info.first.second;
2079 const BlockPhiInfo &phiInfo = info.second;
2080 LLVM_DEBUG({
2081 logger.startLine() << "[phi] block " << block << "\n";
2082 logger.startLine() << "[phi] before creating block argument:\n";
2083 block->getParentOp()->print(logger.getOStream());
2084 logger.startLine() << "\n";
2085 });
2086
2087 // Set insertion point to before this block's terminator early because we
2088 // may materialize ops via getValue() call.
2089 auto *op = block->getTerminator();
2090 opBuilder.setInsertionPoint(op);
2091
2092 SmallVector<Value, 4> blockArgs;
2093 blockArgs.reserve(N: phiInfo.size());
2094 for (uint32_t valueId : phiInfo) {
2095 if (Value value = getValue(id: valueId)) {
2096 blockArgs.push_back(Elt: value);
2097 LLVM_DEBUG(logger.startLine() << "[phi] block argument " << value
2098 << " id = " << valueId << "\n");
2099 } else {
2100 return emitError(loc: unknownLoc, message: "OpPhi references undefined value!");
2101 }
2102 }
2103
2104 if (auto branchOp = dyn_cast<spirv::BranchOp>(op)) {
2105 // Replace the previous branch op with a new one with block arguments.
2106 opBuilder.create<spirv::BranchOp>(branchOp.getLoc(), branchOp.getTarget(),
2107 blockArgs);
2108 branchOp.erase();
2109 } else if (auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) {
2110 assert((branchCondOp.getTrueBlock() == target ||
2111 branchCondOp.getFalseBlock() == target) &&
2112 "expected target to be either the true or false target");
2113 if (target == branchCondOp.getTrueTarget())
2114 opBuilder.create<spirv::BranchConditionalOp>(
2115 branchCondOp.getLoc(), branchCondOp.getCondition(), blockArgs,
2116 branchCondOp.getFalseBlockArguments(),
2117 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueTarget(),
2118 branchCondOp.getFalseTarget());
2119 else
2120 opBuilder.create<spirv::BranchConditionalOp>(
2121 branchCondOp.getLoc(), branchCondOp.getCondition(),
2122 branchCondOp.getTrueBlockArguments(), blockArgs,
2123 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueBlock(),
2124 branchCondOp.getFalseBlock());
2125
2126 branchCondOp.erase();
2127 } else {
2128 return emitError(loc: unknownLoc, message: "unimplemented terminator for Phi creation");
2129 }
2130
2131 LLVM_DEBUG({
2132 logger.startLine() << "[phi] after creating block argument:\n";
2133 block->getParentOp()->print(logger.getOStream());
2134 logger.startLine() << "\n";
2135 });
2136 }
2137 blockPhiInfo.clear();
2138
2139 LLVM_DEBUG({
2140 logger.unindent();
2141 logger.startLine()
2142 << "//--- [phi] completed wiring up block arguments ---//\n";
2143 });
2144 return success();
2145}
2146
2147LogicalResult spirv::Deserializer::structurizeControlFlow() {
2148 LLVM_DEBUG({
2149 logger.startLine()
2150 << "//----- [cf] start structurizing control flow -----//\n";
2151 logger.indent();
2152 });
2153
2154 while (!blockMergeInfo.empty()) {
2155 Block *headerBlock = blockMergeInfo.begin()->first;
2156 BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second;
2157
2158 LLVM_DEBUG({
2159 logger.startLine() << "[cf] header block " << headerBlock << ":\n";
2160 headerBlock->print(logger.getOStream());
2161 logger.startLine() << "\n";
2162 });
2163
2164 auto *mergeBlock = mergeInfo.mergeBlock;
2165 assert(mergeBlock && "merge block cannot be nullptr");
2166 if (!mergeBlock->args_empty())
2167 return emitError(loc: unknownLoc, message: "OpPhi in loop merge block unimplemented");
2168 LLVM_DEBUG({
2169 logger.startLine() << "[cf] merge block " << mergeBlock << ":\n";
2170 mergeBlock->print(logger.getOStream());
2171 logger.startLine() << "\n";
2172 });
2173
2174 auto *continueBlock = mergeInfo.continueBlock;
2175 LLVM_DEBUG(if (continueBlock) {
2176 logger.startLine() << "[cf] continue block " << continueBlock << ":\n";
2177 continueBlock->print(logger.getOStream());
2178 logger.startLine() << "\n";
2179 });
2180 // Erase this case before calling into structurizer, who will update
2181 // blockMergeInfo.
2182 blockMergeInfo.erase(I: blockMergeInfo.begin());
2183 ControlFlowStructurizer structurizer(mergeInfo.loc, mergeInfo.control,
2184 blockMergeInfo, headerBlock,
2185 mergeBlock, continueBlock
2186#ifndef NDEBUG
2187 ,
2188 logger
2189#endif
2190 );
2191 if (failed(result: structurizer.structurize()))
2192 return failure();
2193 }
2194
2195 LLVM_DEBUG({
2196 logger.unindent();
2197 logger.startLine()
2198 << "//--- [cf] completed structurizing control flow ---//\n";
2199 });
2200 return success();
2201}
2202
2203//===----------------------------------------------------------------------===//
2204// Debug
2205//===----------------------------------------------------------------------===//
2206
2207Location spirv::Deserializer::createFileLineColLoc(OpBuilder opBuilder) {
2208 if (!debugLine)
2209 return unknownLoc;
2210
2211 auto fileName = debugInfoMap.lookup(Val: debugLine->fileID).str();
2212 if (fileName.empty())
2213 fileName = "<unknown>";
2214 return FileLineColLoc::get(opBuilder.getStringAttr(fileName), debugLine->line,
2215 debugLine->column);
2216}
2217
2218LogicalResult
2219spirv::Deserializer::processDebugLine(ArrayRef<uint32_t> operands) {
2220 // According to SPIR-V spec:
2221 // "This location information applies to the instructions physically
2222 // following this instruction, up to the first occurrence of any of the
2223 // following: the next end of block, the next OpLine instruction, or the next
2224 // OpNoLine instruction."
2225 if (operands.size() != 3)
2226 return emitError(loc: unknownLoc, message: "OpLine must have 3 operands");
2227 debugLine = DebugLine{.fileID: operands[0], .line: operands[1], .column: operands[2]};
2228 return success();
2229}
2230
2231void spirv::Deserializer::clearDebugLine() { debugLine = std::nullopt; }
2232
2233LogicalResult
2234spirv::Deserializer::processDebugString(ArrayRef<uint32_t> operands) {
2235 if (operands.size() < 2)
2236 return emitError(loc: unknownLoc, message: "OpString needs at least 2 operands");
2237
2238 if (!debugInfoMap.lookup(Val: operands[0]).empty())
2239 return emitError(loc: unknownLoc,
2240 message: "duplicate debug string found for result <id> ")
2241 << operands[0];
2242
2243 unsigned wordIndex = 1;
2244 StringRef debugString = decodeStringLiteral(words: operands, wordIndex);
2245 if (wordIndex != operands.size())
2246 return emitError(loc: unknownLoc,
2247 message: "unexpected trailing words in OpString instruction");
2248
2249 debugInfoMap[operands[0]] = debugString;
2250 return success();
2251}
2252

source code of mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp