1//===- Deserializer.cpp - MLIR SPIR-V Deserializer ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the SPIR-V binary to MLIR SPIR-V module deserializer.
10//
11//===----------------------------------------------------------------------===//
12
13#include "Deserializer.h"
14
15#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
16#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
17#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19#include "mlir/IR/Builders.h"
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/Location.h"
22#include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/Sequence.h"
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/ADT/StringExtras.h"
27#include "llvm/ADT/bit.h"
28#include "llvm/Support/Debug.h"
29#include "llvm/Support/SaveAndRestore.h"
30#include "llvm/Support/raw_ostream.h"
31#include <optional>
32
33using namespace mlir;
34
35#define DEBUG_TYPE "spirv-deserialization"
36
37//===----------------------------------------------------------------------===//
38// Utility Functions
39//===----------------------------------------------------------------------===//
40
41/// Returns true if the given `block` is a function entry block.
42static inline bool isFnEntryBlock(Block *block) {
43 return block->isEntryBlock() &&
44 isa_and_nonnull<spirv::FuncOp>(Val: block->getParentOp());
45}
46
47//===----------------------------------------------------------------------===//
48// Deserializer Method Definitions
49//===----------------------------------------------------------------------===//
50
51spirv::Deserializer::Deserializer(ArrayRef<uint32_t> binary,
52 MLIRContext *context,
53 const spirv::DeserializationOptions &options)
54 : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)),
55 module(createModuleOp()), opBuilder(module->getRegion()), options(options)
56#ifndef NDEBUG
57 ,
58 logger(llvm::dbgs())
59#endif
60{
61}
62
63LogicalResult spirv::Deserializer::deserialize() {
64 LLVM_DEBUG({
65 logger.resetIndent();
66 logger.startLine()
67 << "//+++---------- start deserialization ----------+++//\n";
68 });
69
70 if (failed(Result: processHeader()))
71 return failure();
72
73 spirv::Opcode opcode = spirv::Opcode::OpNop;
74 ArrayRef<uint32_t> operands;
75 auto binarySize = binary.size();
76 while (curOffset < binarySize) {
77 // Slice the next instruction out and populate `opcode` and `operands`.
78 // Internally this also updates `curOffset`.
79 if (failed(Result: sliceInstruction(opcode, operands)))
80 return failure();
81
82 if (failed(Result: processInstruction(opcode, operands)))
83 return failure();
84 }
85
86 assert(curOffset == binarySize &&
87 "deserializer should never index beyond the binary end");
88
89 for (auto &deferred : deferredInstructions) {
90 if (failed(Result: processInstruction(opcode: deferred.first, operands: deferred.second, deferInstructions: false))) {
91 return failure();
92 }
93 }
94
95 attachVCETriple();
96
97 LLVM_DEBUG(logger.startLine()
98 << "//+++-------- completed deserialization --------+++//\n");
99 return success();
100}
101
102OwningOpRef<spirv::ModuleOp> spirv::Deserializer::collect() {
103 return std::move(module);
104}
105
106//===----------------------------------------------------------------------===//
107// Module structure
108//===----------------------------------------------------------------------===//
109
110OwningOpRef<spirv::ModuleOp> spirv::Deserializer::createModuleOp() {
111 OpBuilder builder(context);
112 OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
113 spirv::ModuleOp::build(odsBuilder&: builder, odsState&: state);
114 return cast<spirv::ModuleOp>(Val: Operation::create(state));
115}
116
117LogicalResult spirv::Deserializer::processHeader() {
118 if (binary.size() < spirv::kHeaderWordCount)
119 return emitError(loc: unknownLoc,
120 message: "SPIR-V binary module must have a 5-word header");
121
122 if (binary[0] != spirv::kMagicNumber)
123 return emitError(loc: unknownLoc, message: "incorrect magic number");
124
125 // Version number bytes: 0 | major number | minor number | 0
126 uint32_t majorVersion = (binary[1] << 8) >> 24;
127 uint32_t minorVersion = (binary[1] << 16) >> 24;
128 if (majorVersion == 1) {
129 switch (minorVersion) {
130#define MIN_VERSION_CASE(v) \
131 case v: \
132 version = spirv::Version::V_1_##v; \
133 break
134
135 MIN_VERSION_CASE(0);
136 MIN_VERSION_CASE(1);
137 MIN_VERSION_CASE(2);
138 MIN_VERSION_CASE(3);
139 MIN_VERSION_CASE(4);
140 MIN_VERSION_CASE(5);
141#undef MIN_VERSION_CASE
142 default:
143 return emitError(loc: unknownLoc, message: "unsupported SPIR-V minor version: ")
144 << minorVersion;
145 }
146 } else {
147 return emitError(loc: unknownLoc, message: "unsupported SPIR-V major version: ")
148 << majorVersion;
149 }
150
151 // TODO: generator number, bound, schema
152 curOffset = spirv::kHeaderWordCount;
153 return success();
154}
155
156LogicalResult
157spirv::Deserializer::processCapability(ArrayRef<uint32_t> operands) {
158 if (operands.size() != 1)
159 return emitError(loc: unknownLoc, message: "OpCapability must have one parameter");
160
161 auto cap = spirv::symbolizeCapability(operands[0]);
162 if (!cap)
163 return emitError(loc: unknownLoc, message: "unknown capability: ") << operands[0];
164
165 capabilities.insert(X: *cap);
166 return success();
167}
168
169LogicalResult spirv::Deserializer::processExtension(ArrayRef<uint32_t> words) {
170 if (words.empty()) {
171 return emitError(
172 loc: unknownLoc,
173 message: "OpExtension must have a literal string for the extension name");
174 }
175
176 unsigned wordIndex = 0;
177 StringRef extName = decodeStringLiteral(words, wordIndex);
178 if (wordIndex != words.size())
179 return emitError(loc: unknownLoc,
180 message: "unexpected trailing words in OpExtension instruction");
181 auto ext = spirv::symbolizeExtension(extName);
182 if (!ext)
183 return emitError(loc: unknownLoc, message: "unknown extension: ") << extName;
184
185 extensions.insert(X: *ext);
186 return success();
187}
188
189LogicalResult
190spirv::Deserializer::processExtInstImport(ArrayRef<uint32_t> words) {
191 if (words.size() < 2) {
192 return emitError(loc: unknownLoc,
193 message: "OpExtInstImport must have a result <id> and a literal "
194 "string for the extended instruction set name");
195 }
196
197 unsigned wordIndex = 1;
198 extendedInstSets[words[0]] = decodeStringLiteral(words, wordIndex);
199 if (wordIndex != words.size()) {
200 return emitError(loc: unknownLoc,
201 message: "unexpected trailing words in OpExtInstImport");
202 }
203 return success();
204}
205
206void spirv::Deserializer::attachVCETriple() {
207 (*module)->setAttr(
208 name: spirv::ModuleOp::getVCETripleAttrName(),
209 value: spirv::VerCapExtAttr::get(version, capabilities: capabilities.getArrayRef(),
210 extensions: extensions.getArrayRef(), context));
211}
212
213LogicalResult
214spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
215 if (operands.size() != 2)
216 return emitError(loc: unknownLoc, message: "OpMemoryModel must have two operands");
217
218 (*module)->setAttr(
219 name: module->getAddressingModelAttrName(),
220 value: opBuilder.getAttr<spirv::AddressingModelAttr>(
221 args: static_cast<spirv::AddressingModel>(operands.front())));
222
223 (*module)->setAttr(name: module->getMemoryModelAttrName(),
224 value: opBuilder.getAttr<spirv::MemoryModelAttr>(
225 args: static_cast<spirv::MemoryModel>(operands.back())));
226
227 return success();
228}
229
230template <typename AttrTy, typename EnumAttrTy, typename EnumTy>
231LogicalResult deserializeCacheControlDecoration(
232 Location loc, OpBuilder &opBuilder,
233 DenseMap<uint32_t, NamedAttrList> &decorations, ArrayRef<uint32_t> words,
234 StringAttr symbol, StringRef decorationName, StringRef cacheControlKind) {
235 if (words.size() != 4) {
236 return emitError(loc, message: "OpDecoration with ")
237 << decorationName << "needs a cache control integer literal and a "
238 << cacheControlKind << " cache control literal";
239 }
240 unsigned cacheLevel = words[2];
241 auto cacheControlAttr = static_cast<EnumTy>(words[3]);
242 auto value = opBuilder.getAttr<AttrTy>(cacheLevel, cacheControlAttr);
243 SmallVector<Attribute> attrs;
244 if (auto attrList =
245 llvm::dyn_cast_or_null<ArrayAttr>(Val: decorations[words[0]].get(name: symbol)))
246 llvm::append_range(C&: attrs, R&: attrList);
247 attrs.push_back(Elt: value);
248 decorations[words[0]].set(name: symbol, value: opBuilder.getArrayAttr(value: attrs));
249 return success();
250}
251
252LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
253 // TODO: This function should also be auto-generated. For now, since only a
254 // few decorations are processed/handled in a meaningful manner, going with a
255 // manual implementation.
256 if (words.size() < 2) {
257 return emitError(
258 loc: unknownLoc, message: "OpDecorate must have at least result <id> and Decoration");
259 }
260 auto decorationName =
261 stringifyDecoration(static_cast<spirv::Decoration>(words[1]));
262 if (decorationName.empty()) {
263 return emitError(loc: unknownLoc, message: "invalid Decoration code : ") << words[1];
264 }
265 auto symbol = getSymbolDecoration(decorationName);
266 switch (static_cast<spirv::Decoration>(words[1])) {
267 case spirv::Decoration::FPFastMathMode:
268 if (words.size() != 3) {
269 return emitError(loc: unknownLoc, message: "OpDecorate with ")
270 << decorationName << " needs a single integer literal";
271 }
272 decorations[words[0]].set(
273 name: symbol, value: FPFastMathModeAttr::get(context: opBuilder.getContext(),
274 value: static_cast<FPFastMathMode>(words[2])));
275 break;
276 case spirv::Decoration::FPRoundingMode:
277 if (words.size() != 3) {
278 return emitError(loc: unknownLoc, message: "OpDecorate with ")
279 << decorationName << " needs a single integer literal";
280 }
281 decorations[words[0]].set(
282 name: symbol, value: FPRoundingModeAttr::get(context: opBuilder.getContext(),
283 value: static_cast<FPRoundingMode>(words[2])));
284 break;
285 case spirv::Decoration::DescriptorSet:
286 case spirv::Decoration::Binding:
287 if (words.size() != 3) {
288 return emitError(loc: unknownLoc, message: "OpDecorate with ")
289 << decorationName << " needs a single integer literal";
290 }
291 decorations[words[0]].set(
292 name: symbol, value: opBuilder.getI32IntegerAttr(value: static_cast<int32_t>(words[2])));
293 break;
294 case spirv::Decoration::BuiltIn:
295 if (words.size() != 3) {
296 return emitError(loc: unknownLoc, message: "OpDecorate with ")
297 << decorationName << " needs a single integer literal";
298 }
299 decorations[words[0]].set(
300 name: symbol, value: opBuilder.getStringAttr(
301 bytes: stringifyBuiltIn(static_cast<spirv::BuiltIn>(words[2]))));
302 break;
303 case spirv::Decoration::ArrayStride:
304 if (words.size() != 3) {
305 return emitError(loc: unknownLoc, message: "OpDecorate with ")
306 << decorationName << " needs a single integer literal";
307 }
308 typeDecorations[words[0]] = words[2];
309 break;
310 case spirv::Decoration::LinkageAttributes: {
311 if (words.size() < 4) {
312 return emitError(loc: unknownLoc, message: "OpDecorate with ")
313 << decorationName
314 << " needs at least 1 string and 1 integer literal";
315 }
316 // LinkageAttributes has two parameters ["linkageName", linkageType]
317 // e.g., OpDecorate %imported_func LinkageAttributes "outside.func" Import
318 // "linkageName" is a stringliteral encoded as uint32_t,
319 // hence the size of name is variable length which results in words.size()
320 // being variable length, words.size() = 3 + strlen(name)/4 + 1 or
321 // 3 + ceildiv(strlen(name), 4).
322 unsigned wordIndex = 2;
323 auto linkageName = spirv::decodeStringLiteral(words, wordIndex).str();
324 auto linkageTypeAttr = opBuilder.getAttr<::mlir::spirv::LinkageTypeAttr>(
325 args: static_cast<::mlir::spirv::LinkageType>(words[wordIndex++]));
326 auto linkageAttr = opBuilder.getAttr<::mlir::spirv::LinkageAttributesAttr>(
327 args: StringAttr::get(context, bytes: linkageName), args&: linkageTypeAttr);
328 decorations[words[0]].set(name: symbol, value: llvm::dyn_cast<Attribute>(Val&: linkageAttr));
329 break;
330 }
331 case spirv::Decoration::Aliased:
332 case spirv::Decoration::AliasedPointer:
333 case spirv::Decoration::Block:
334 case spirv::Decoration::BufferBlock:
335 case spirv::Decoration::Flat:
336 case spirv::Decoration::NonReadable:
337 case spirv::Decoration::NonWritable:
338 case spirv::Decoration::NoPerspective:
339 case spirv::Decoration::NoSignedWrap:
340 case spirv::Decoration::NoUnsignedWrap:
341 case spirv::Decoration::RelaxedPrecision:
342 case spirv::Decoration::Restrict:
343 case spirv::Decoration::RestrictPointer:
344 case spirv::Decoration::NoContraction:
345 case spirv::Decoration::Constant:
346 if (words.size() != 2) {
347 return emitError(loc: unknownLoc, message: "OpDecoration with ")
348 << decorationName << "needs a single target <id>";
349 }
350 // Block decoration does not affect spirv.struct type, but is still stored
351 // for verification.
352 // TODO: Update StructType to contain this information since
353 // it is needed for many validation rules.
354 decorations[words[0]].set(name: symbol, value: opBuilder.getUnitAttr());
355 break;
356 case spirv::Decoration::Location:
357 case spirv::Decoration::SpecId:
358 if (words.size() != 3) {
359 return emitError(loc: unknownLoc, message: "OpDecoration with ")
360 << decorationName << "needs a single integer literal";
361 }
362 decorations[words[0]].set(
363 name: symbol, value: opBuilder.getI32IntegerAttr(value: static_cast<int32_t>(words[2])));
364 break;
365 case spirv::Decoration::CacheControlLoadINTEL: {
366 LogicalResult res = deserializeCacheControlDecoration<
367 CacheControlLoadINTELAttr, LoadCacheControlAttr, LoadCacheControl>(
368 loc: unknownLoc, opBuilder, decorations, words, symbol, decorationName,
369 cacheControlKind: "load");
370 if (failed(Result: res))
371 return res;
372 break;
373 }
374 case spirv::Decoration::CacheControlStoreINTEL: {
375 LogicalResult res = deserializeCacheControlDecoration<
376 CacheControlStoreINTELAttr, StoreCacheControlAttr, StoreCacheControl>(
377 loc: unknownLoc, opBuilder, decorations, words, symbol, decorationName,
378 cacheControlKind: "store");
379 if (failed(Result: res))
380 return res;
381 break;
382 }
383 default:
384 return emitError(loc: unknownLoc, message: "unhandled Decoration : '") << decorationName;
385 }
386 return success();
387}
388
389LogicalResult
390spirv::Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) {
391 // The binary layout of OpMemberDecorate is different comparing to OpDecorate
392 if (words.size() < 3) {
393 return emitError(loc: unknownLoc,
394 message: "OpMemberDecorate must have at least 3 operands");
395 }
396
397 auto decoration = static_cast<spirv::Decoration>(words[2]);
398 if (decoration == spirv::Decoration::Offset && words.size() != 4) {
399 return emitError(loc: unknownLoc,
400 message: " missing offset specification in OpMemberDecorate with "
401 "Offset decoration");
402 }
403 ArrayRef<uint32_t> decorationOperands;
404 if (words.size() > 3) {
405 decorationOperands = words.slice(N: 3);
406 }
407 memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands;
408 return success();
409}
410
411LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) {
412 if (words.size() < 3) {
413 return emitError(loc: unknownLoc, message: "OpMemberName must have at least 3 operands");
414 }
415 unsigned wordIndex = 2;
416 auto name = decodeStringLiteral(words, wordIndex);
417 if (wordIndex != words.size()) {
418 return emitError(loc: unknownLoc,
419 message: "unexpected trailing words in OpMemberName instruction");
420 }
421 memberNameMap[words[0]][words[1]] = name;
422 return success();
423}
424
425LogicalResult spirv::Deserializer::setFunctionArgAttrs(
426 uint32_t argID, SmallVectorImpl<Attribute> &argAttrs, size_t argIndex) {
427 if (!decorations.contains(Val: argID)) {
428 argAttrs[argIndex] = DictionaryAttr::get(context, value: {});
429 return success();
430 }
431
432 spirv::DecorationAttr foundDecorationAttr;
433 for (NamedAttribute decAttr : decorations[argID]) {
434 for (auto decoration :
435 {spirv::Decoration::Aliased, spirv::Decoration::Restrict,
436 spirv::Decoration::AliasedPointer,
437 spirv::Decoration::RestrictPointer}) {
438
439 if (decAttr.getName() !=
440 getSymbolDecoration(decorationName: stringifyDecoration(decoration)))
441 continue;
442
443 if (foundDecorationAttr)
444 return emitError(loc: unknownLoc,
445 message: "more than one Aliased/Restrict decorations for "
446 "function argument with result <id> ")
447 << argID;
448
449 foundDecorationAttr = spirv::DecorationAttr::get(context, value: decoration);
450 break;
451 }
452
453 if (decAttr.getName() == getSymbolDecoration(decorationName: stringifyDecoration(
454 spirv::Decoration::RelaxedPrecision))) {
455 // TODO: Current implementation supports only one decoration per function
456 // parameter so RelaxedPrecision cannot be applied at the same time as,
457 // for example, Aliased/Restrict/etc. This should be relaxed to allow any
458 // combination of decoration allowed by the spec to be supported.
459 if (foundDecorationAttr)
460 return emitError(loc: unknownLoc, message: "already found a decoration for function "
461 "argument with result <id> ")
462 << argID;
463
464 foundDecorationAttr = spirv::DecorationAttr::get(
465 context, value: spirv::Decoration::RelaxedPrecision);
466 }
467 }
468
469 if (!foundDecorationAttr)
470 return emitError(loc: unknownLoc, message: "unimplemented decoration support for "
471 "function argument with result <id> ")
472 << argID;
473
474 NamedAttribute attr(StringAttr::get(context, bytes: spirv::DecorationAttr::name),
475 foundDecorationAttr);
476 argAttrs[argIndex] = DictionaryAttr::get(context, value: attr);
477 return success();
478}
479
480LogicalResult
481spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) {
482 if (curFunction) {
483 return emitError(loc: unknownLoc, message: "found function inside function");
484 }
485
486 // Get the result type
487 if (operands.size() != 4) {
488 return emitError(loc: unknownLoc, message: "OpFunction must have 4 parameters");
489 }
490 Type resultType = getType(id: operands[0]);
491 if (!resultType) {
492 return emitError(loc: unknownLoc, message: "undefined result type from <id> ")
493 << operands[0];
494 }
495
496 uint32_t fnID = operands[1];
497 if (funcMap.count(Val: fnID)) {
498 return emitError(loc: unknownLoc, message: "duplicate function definition/declaration");
499 }
500
501 auto fnControl = spirv::symbolizeFunctionControl(operands[2]);
502 if (!fnControl) {
503 return emitError(loc: unknownLoc, message: "unknown Function Control: ") << operands[2];
504 }
505
506 Type fnType = getType(id: operands[3]);
507 if (!fnType || !isa<FunctionType>(Val: fnType)) {
508 return emitError(loc: unknownLoc, message: "unknown function type from <id> ")
509 << operands[3];
510 }
511 auto functionType = cast<FunctionType>(Val&: fnType);
512
513 if ((isVoidType(type: resultType) && functionType.getNumResults() != 0) ||
514 (functionType.getNumResults() == 1 &&
515 functionType.getResult(i: 0) != resultType)) {
516 return emitError(loc: unknownLoc, message: "mismatch in function type ")
517 << functionType << " and return type " << resultType << " specified";
518 }
519
520 std::string fnName = getFunctionSymbol(id: fnID);
521 auto funcOp = opBuilder.create<spirv::FuncOp>(
522 location: unknownLoc, args&: fnName, args&: functionType, args&: fnControl.value());
523 // Processing other function attributes.
524 if (decorations.count(Val: fnID)) {
525 for (auto attr : decorations[fnID].getAttrs()) {
526 funcOp->setAttr(name: attr.getName(), value: attr.getValue());
527 }
528 }
529 curFunction = funcMap[fnID] = funcOp;
530 auto *entryBlock = funcOp.addEntryBlock();
531 LLVM_DEBUG({
532 logger.startLine()
533 << "//===-------------------------------------------===//\n";
534 logger.startLine() << "[fn] name: " << fnName << "\n";
535 logger.startLine() << "[fn] type: " << fnType << "\n";
536 logger.startLine() << "[fn] ID: " << fnID << "\n";
537 logger.startLine() << "[fn] entry block: " << entryBlock << "\n";
538 logger.indent();
539 });
540
541 SmallVector<Attribute> argAttrs;
542 argAttrs.resize(N: functionType.getNumInputs());
543
544 // Parse the op argument instructions
545 if (functionType.getNumInputs()) {
546 for (size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
547 auto argType = functionType.getInput(i);
548 spirv::Opcode opcode = spirv::Opcode::OpNop;
549 ArrayRef<uint32_t> operands;
550 if (failed(Result: sliceInstruction(opcode, operands,
551 expectedOpcode: spirv::Opcode::OpFunctionParameter))) {
552 return failure();
553 }
554 if (opcode != spirv::Opcode::OpFunctionParameter) {
555 return emitError(
556 loc: unknownLoc,
557 message: "missing OpFunctionParameter instruction for argument ")
558 << i;
559 }
560 if (operands.size() != 2) {
561 return emitError(
562 loc: unknownLoc,
563 message: "expected result type and result <id> for OpFunctionParameter");
564 }
565 auto argDefinedType = getType(id: operands[0]);
566 if (!argDefinedType || argDefinedType != argType) {
567 return emitError(loc: unknownLoc,
568 message: "mismatch in argument type between function type "
569 "definition ")
570 << functionType << " and argument type definition "
571 << argDefinedType << " at argument " << i;
572 }
573 if (getValue(id: operands[1])) {
574 return emitError(loc: unknownLoc, message: "duplicate definition of result <id> ")
575 << operands[1];
576 }
577 if (failed(Result: setFunctionArgAttrs(argID: operands[1], argAttrs, argIndex: i))) {
578 return failure();
579 }
580
581 auto argValue = funcOp.getArgument(idx: i);
582 valueMap[operands[1]] = argValue;
583 }
584 }
585
586 if (llvm::any_of(Range&: argAttrs, P: [](Attribute attr) {
587 auto argAttr = cast<DictionaryAttr>(Val&: attr);
588 return !argAttr.empty();
589 }))
590 funcOp.setArgAttrsAttr(ArrayAttr::get(context, value: argAttrs));
591
592 // entryBlock is needed to access the arguments, Once that is done, we can
593 // erase the block for functions with 'Import' LinkageAttributes, since these
594 // are essentially function declarations, so they have no body.
595 auto linkageAttr = funcOp.getLinkageAttributes();
596 auto hasImportLinkage =
597 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
598 spirv::LinkageType::Import);
599 if (hasImportLinkage)
600 funcOp.eraseBody();
601
602 // RAII guard to reset the insertion point to the module's region after
603 // deserializing the body of this function.
604 OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
605
606 spirv::Opcode opcode = spirv::Opcode::OpNop;
607 ArrayRef<uint32_t> instOperands;
608
609 // Special handling for the entry block. We need to make sure it starts with
610 // an OpLabel instruction. The entry block takes the same parameters as the
611 // function. All other blocks do not take any parameter. We have already
612 // created the entry block, here we need to register it to the correct label
613 // <id>.
614 if (failed(Result: sliceInstruction(opcode, operands&: instOperands,
615 expectedOpcode: spirv::Opcode::OpFunctionEnd))) {
616 return failure();
617 }
618 if (opcode == spirv::Opcode::OpFunctionEnd) {
619 return processFunctionEnd(operands: instOperands);
620 }
621 if (opcode != spirv::Opcode::OpLabel) {
622 return emitError(loc: unknownLoc, message: "a basic block must start with OpLabel");
623 }
624 if (instOperands.size() != 1) {
625 return emitError(loc: unknownLoc, message: "OpLabel should only have result <id>");
626 }
627 blockMap[instOperands[0]] = entryBlock;
628 if (failed(Result: processLabel(operands: instOperands))) {
629 return failure();
630 }
631
632 // Then process all the other instructions in the function until we hit
633 // OpFunctionEnd.
634 while (succeeded(Result: sliceInstruction(opcode, operands&: instOperands,
635 expectedOpcode: spirv::Opcode::OpFunctionEnd)) &&
636 opcode != spirv::Opcode::OpFunctionEnd) {
637 if (failed(Result: processInstruction(opcode, operands: instOperands))) {
638 return failure();
639 }
640 }
641 if (opcode != spirv::Opcode::OpFunctionEnd) {
642 return failure();
643 }
644
645 return processFunctionEnd(operands: instOperands);
646}
647
648LogicalResult
649spirv::Deserializer::processFunctionEnd(ArrayRef<uint32_t> operands) {
650 // Process OpFunctionEnd.
651 if (!operands.empty()) {
652 return emitError(loc: unknownLoc, message: "unexpected operands for OpFunctionEnd");
653 }
654
655 // Wire up block arguments from OpPhi instructions.
656 // Put all structured control flow in spirv.mlir.selection/spirv.mlir.loop
657 // ops.
658 if (failed(Result: wireUpBlockArgument()) || failed(Result: structurizeControlFlow())) {
659 return failure();
660 }
661
662 curBlock = nullptr;
663 curFunction = std::nullopt;
664
665 LLVM_DEBUG({
666 logger.unindent();
667 logger.startLine()
668 << "//===-------------------------------------------===//\n";
669 });
670 return success();
671}
672
673std::optional<std::pair<Attribute, Type>>
674spirv::Deserializer::getConstant(uint32_t id) {
675 auto constIt = constantMap.find(Val: id);
676 if (constIt == constantMap.end())
677 return std::nullopt;
678 return constIt->getSecond();
679}
680
681std::optional<std::pair<Attribute, Type>>
682spirv::Deserializer::getConstantCompositeReplicate(uint32_t id) {
683 if (auto it = constantCompositeReplicateMap.find(Val: id);
684 it != constantCompositeReplicateMap.end())
685 return it->second;
686 return std::nullopt;
687}
688
689std::optional<spirv::SpecConstOperationMaterializationInfo>
690spirv::Deserializer::getSpecConstantOperation(uint32_t id) {
691 auto constIt = specConstOperationMap.find(Val: id);
692 if (constIt == specConstOperationMap.end())
693 return std::nullopt;
694 return constIt->getSecond();
695}
696
697std::string spirv::Deserializer::getFunctionSymbol(uint32_t id) {
698 auto funcName = nameMap.lookup(Val: id).str();
699 if (funcName.empty()) {
700 funcName = "spirv_fn_" + std::to_string(val: id);
701 }
702 return funcName;
703}
704
705std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t id) {
706 auto constName = nameMap.lookup(Val: id).str();
707 if (constName.empty()) {
708 constName = "spirv_spec_const_" + std::to_string(val: id);
709 }
710 return constName;
711}
712
713spirv::SpecConstantOp
714spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID,
715 TypedAttr defaultValue) {
716 auto symName = opBuilder.getStringAttr(bytes: getSpecConstantSymbol(id: resultID));
717 auto op = opBuilder.create<spirv::SpecConstantOp>(location: unknownLoc, args&: symName,
718 args&: defaultValue);
719 if (decorations.count(Val: resultID)) {
720 for (auto attr : decorations[resultID].getAttrs())
721 op->setAttr(name: attr.getName(), value: attr.getValue());
722 }
723 specConstMap[resultID] = op;
724 return op;
725}
726
727LogicalResult
728spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
729 unsigned wordIndex = 0;
730 if (operands.size() < 3) {
731 return emitError(
732 loc: unknownLoc,
733 message: "OpVariable needs at least 3 operands, type, <id> and storage class");
734 }
735
736 // Result Type.
737 auto type = getType(id: operands[wordIndex]);
738 if (!type) {
739 return emitError(loc: unknownLoc, message: "unknown result type <id> : ")
740 << operands[wordIndex];
741 }
742 auto ptrType = dyn_cast<spirv::PointerType>(Val&: type);
743 if (!ptrType) {
744 return emitError(loc: unknownLoc,
745 message: "expected a result type <id> to be a spirv.ptr, found : ")
746 << type;
747 }
748 wordIndex++;
749
750 // Result <id>.
751 auto variableID = operands[wordIndex];
752 auto variableName = nameMap.lookup(Val: variableID).str();
753 if (variableName.empty()) {
754 variableName = "spirv_var_" + std::to_string(val: variableID);
755 }
756 wordIndex++;
757
758 // Storage class.
759 auto storageClass = static_cast<spirv::StorageClass>(operands[wordIndex]);
760 if (ptrType.getStorageClass() != storageClass) {
761 return emitError(loc: unknownLoc, message: "mismatch in storage class of pointer type ")
762 << type << " and that specified in OpVariable instruction : "
763 << stringifyStorageClass(storageClass);
764 }
765 wordIndex++;
766
767 // Initializer.
768 FlatSymbolRefAttr initializer = nullptr;
769
770 if (wordIndex < operands.size()) {
771 Operation *op = nullptr;
772
773 if (auto initOp = getGlobalVariable(id: operands[wordIndex]))
774 op = initOp;
775 else if (auto initOp = getSpecConstant(id: operands[wordIndex]))
776 op = initOp;
777 else if (auto initOp = getSpecConstantComposite(id: operands[wordIndex]))
778 op = initOp;
779 else
780 return emitError(loc: unknownLoc, message: "unknown <id> ")
781 << operands[wordIndex] << "used as initializer";
782
783 initializer = SymbolRefAttr::get(symbol: op);
784 wordIndex++;
785 }
786 if (wordIndex != operands.size()) {
787 return emitError(loc: unknownLoc,
788 message: "found more operands than expected when deserializing "
789 "OpVariable instruction, only ")
790 << wordIndex << " of " << operands.size() << " processed";
791 }
792 auto loc = createFileLineColLoc(opBuilder);
793 auto varOp = opBuilder.create<spirv::GlobalVariableOp>(
794 location: loc, args: TypeAttr::get(type), args: opBuilder.getStringAttr(bytes: variableName),
795 args&: initializer);
796
797 // Decorations.
798 if (decorations.count(Val: variableID)) {
799 for (auto attr : decorations[variableID].getAttrs())
800 varOp->setAttr(name: attr.getName(), value: attr.getValue());
801 }
802 globalVariableMap[variableID] = varOp;
803 return success();
804}
805
806IntegerAttr spirv::Deserializer::getConstantInt(uint32_t id) {
807 auto constInfo = getConstant(id);
808 if (!constInfo) {
809 return nullptr;
810 }
811 return dyn_cast<IntegerAttr>(Val&: constInfo->first);
812}
813
814LogicalResult spirv::Deserializer::processName(ArrayRef<uint32_t> operands) {
815 if (operands.size() < 2) {
816 return emitError(loc: unknownLoc, message: "OpName needs at least 2 operands");
817 }
818 if (!nameMap.lookup(Val: operands[0]).empty()) {
819 return emitError(loc: unknownLoc, message: "duplicate name found for result <id> ")
820 << operands[0];
821 }
822 unsigned wordIndex = 1;
823 StringRef name = decodeStringLiteral(words: operands, wordIndex);
824 if (wordIndex != operands.size()) {
825 return emitError(loc: unknownLoc,
826 message: "unexpected trailing words in OpName instruction");
827 }
828 nameMap[operands[0]] = name;
829 return success();
830}
831
832//===----------------------------------------------------------------------===//
833// Type
834//===----------------------------------------------------------------------===//
835
836LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
837 ArrayRef<uint32_t> operands) {
838 if (operands.empty()) {
839 return emitError(loc: unknownLoc, message: "type instruction with opcode ")
840 << spirv::stringifyOpcode(opcode) << " needs at least one <id>";
841 }
842
843 /// TODO: Types might be forward declared in some instructions and need to be
844 /// handled appropriately.
845 if (typeMap.count(Val: operands[0])) {
846 return emitError(loc: unknownLoc, message: "duplicate definition for result <id> ")
847 << operands[0];
848 }
849
850 switch (opcode) {
851 case spirv::Opcode::OpTypeVoid:
852 if (operands.size() != 1)
853 return emitError(loc: unknownLoc, message: "OpTypeVoid must have no parameters");
854 typeMap[operands[0]] = opBuilder.getNoneType();
855 break;
856 case spirv::Opcode::OpTypeBool:
857 if (operands.size() != 1)
858 return emitError(loc: unknownLoc, message: "OpTypeBool must have no parameters");
859 typeMap[operands[0]] = opBuilder.getI1Type();
860 break;
861 case spirv::Opcode::OpTypeInt: {
862 if (operands.size() != 3)
863 return emitError(
864 loc: unknownLoc, message: "OpTypeInt must have bitwidth and signedness parameters");
865
866 // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
867 // to preserve or validate.
868 // 0 indicates unsigned, or no signedness semantics
869 // 1 indicates signed semantics."
870 //
871 // So we cannot differentiate signless and unsigned integers; always use
872 // signless semantics for such cases.
873 auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed
874 : IntegerType::SignednessSemantics::Signless;
875 typeMap[operands[0]] = IntegerType::get(context, width: operands[1], signedness: sign);
876 } break;
877 case spirv::Opcode::OpTypeFloat: {
878 if (operands.size() != 2 && operands.size() != 3)
879 return emitError(loc: unknownLoc,
880 message: "OpTypeFloat expects either 2 operands (type, bitwidth) "
881 "or 3 operands (type, bitwidth, encoding), but got ")
882 << operands.size();
883 uint32_t bitWidth = operands[1];
884
885 Type floatTy;
886 switch (bitWidth) {
887 case 16:
888 floatTy = opBuilder.getF16Type();
889 break;
890 case 32:
891 floatTy = opBuilder.getF32Type();
892 break;
893 case 64:
894 floatTy = opBuilder.getF64Type();
895 break;
896 default:
897 return emitError(loc: unknownLoc, message: "unsupported OpTypeFloat bitwidth: ")
898 << bitWidth;
899 }
900
901 if (operands.size() == 3) {
902 if (spirv::FPEncoding(operands[2]) != spirv::FPEncoding::BFloat16KHR)
903 return emitError(loc: unknownLoc, message: "unsupported OpTypeFloat FP encoding: ")
904 << operands[2];
905 if (bitWidth != 16)
906 return emitError(loc: unknownLoc,
907 message: "invalid OpTypeFloat bitwidth for bfloat16 encoding: ")
908 << bitWidth << " (expected 16)";
909 floatTy = opBuilder.getBF16Type();
910 }
911
912 typeMap[operands[0]] = floatTy;
913 } break;
914 case spirv::Opcode::OpTypeVector: {
915 if (operands.size() != 3) {
916 return emitError(
917 loc: unknownLoc,
918 message: "OpTypeVector must have element type and count parameters");
919 }
920 Type elementTy = getType(id: operands[1]);
921 if (!elementTy) {
922 return emitError(loc: unknownLoc, message: "OpTypeVector references undefined <id> ")
923 << operands[1];
924 }
925 typeMap[operands[0]] = VectorType::get(shape: {operands[2]}, elementType: elementTy);
926 } break;
927 case spirv::Opcode::OpTypePointer: {
928 return processOpTypePointer(operands);
929 } break;
930 case spirv::Opcode::OpTypeArray:
931 return processArrayType(operands);
932 case spirv::Opcode::OpTypeCooperativeMatrixKHR:
933 return processCooperativeMatrixTypeKHR(operands);
934 case spirv::Opcode::OpTypeFunction:
935 return processFunctionType(operands);
936 case spirv::Opcode::OpTypeImage:
937 return processImageType(operands);
938 case spirv::Opcode::OpTypeSampledImage:
939 return processSampledImageType(operands);
940 case spirv::Opcode::OpTypeRuntimeArray:
941 return processRuntimeArrayType(operands);
942 case spirv::Opcode::OpTypeStruct:
943 return processStructType(operands);
944 case spirv::Opcode::OpTypeMatrix:
945 return processMatrixType(operands);
946 case spirv::Opcode::OpTypeTensorARM:
947 return processTensorARMType(operands);
948 default:
949 return emitError(loc: unknownLoc, message: "unhandled type instruction");
950 }
951 return success();
952}
953
954LogicalResult
955spirv::Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) {
956 if (operands.size() != 3)
957 return emitError(loc: unknownLoc, message: "OpTypePointer must have two parameters");
958
959 auto pointeeType = getType(id: operands[2]);
960 if (!pointeeType)
961 return emitError(loc: unknownLoc, message: "unknown OpTypePointer pointee type <id> ")
962 << operands[2];
963
964 uint32_t typePointerID = operands[0];
965 auto storageClass = static_cast<spirv::StorageClass>(operands[1]);
966 typeMap[typePointerID] = spirv::PointerType::get(pointeeType, storageClass);
967
968 for (auto *deferredStructIt = std::begin(cont&: deferredStructTypesInfos);
969 deferredStructIt != std::end(cont&: deferredStructTypesInfos);) {
970 for (auto *unresolvedMemberIt =
971 std::begin(cont&: deferredStructIt->unresolvedMemberTypes);
972 unresolvedMemberIt !=
973 std::end(cont&: deferredStructIt->unresolvedMemberTypes);) {
974 if (unresolvedMemberIt->first == typePointerID) {
975 // The newly constructed pointer type can resolve one of the
976 // deferred struct type members; update the memberTypes list and
977 // clean the unresolvedMemberTypes list accordingly.
978 deferredStructIt->memberTypes[unresolvedMemberIt->second] =
979 typeMap[typePointerID];
980 unresolvedMemberIt =
981 deferredStructIt->unresolvedMemberTypes.erase(CI: unresolvedMemberIt);
982 } else {
983 ++unresolvedMemberIt;
984 }
985 }
986
987 if (deferredStructIt->unresolvedMemberTypes.empty()) {
988 // All deferred struct type members are now resolved, set the struct body.
989 auto structType = deferredStructIt->deferredStructType;
990
991 assert(structType && "expected a spirv::StructType");
992 assert(structType.isIdentified() && "expected an indentified struct");
993
994 if (failed(Result: structType.trySetBody(
995 memberTypes: deferredStructIt->memberTypes, offsetInfo: deferredStructIt->offsetInfo,
996 memberDecorations: deferredStructIt->memberDecorationsInfo)))
997 return failure();
998
999 deferredStructIt = deferredStructTypesInfos.erase(CI: deferredStructIt);
1000 } else {
1001 ++deferredStructIt;
1002 }
1003 }
1004
1005 return success();
1006}
1007
1008LogicalResult
1009spirv::Deserializer::processArrayType(ArrayRef<uint32_t> operands) {
1010 if (operands.size() != 3) {
1011 return emitError(loc: unknownLoc,
1012 message: "OpTypeArray must have element type and count parameters");
1013 }
1014
1015 Type elementTy = getType(id: operands[1]);
1016 if (!elementTy) {
1017 return emitError(loc: unknownLoc, message: "OpTypeArray references undefined <id> ")
1018 << operands[1];
1019 }
1020
1021 unsigned count = 0;
1022 // TODO: The count can also come frome a specialization constant.
1023 auto countInfo = getConstant(id: operands[2]);
1024 if (!countInfo) {
1025 return emitError(loc: unknownLoc, message: "OpTypeArray count <id> ")
1026 << operands[2] << "can only come from normal constant right now";
1027 }
1028
1029 if (auto intVal = dyn_cast<IntegerAttr>(Val&: countInfo->first)) {
1030 count = intVal.getValue().getZExtValue();
1031 } else {
1032 return emitError(loc: unknownLoc, message: "OpTypeArray count must come from a "
1033 "scalar integer constant instruction");
1034 }
1035
1036 typeMap[operands[0]] = spirv::ArrayType::get(
1037 elementType: elementTy, elementCount: count, stride: typeDecorations.lookup(Val: operands[0]));
1038 return success();
1039}
1040
1041LogicalResult
1042spirv::Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
1043 assert(!operands.empty() && "No operands for processing function type");
1044 if (operands.size() == 1) {
1045 return emitError(loc: unknownLoc, message: "missing return type for OpTypeFunction");
1046 }
1047 auto returnType = getType(id: operands[1]);
1048 if (!returnType) {
1049 return emitError(loc: unknownLoc, message: "unknown return type in OpTypeFunction");
1050 }
1051 SmallVector<Type, 1> argTypes;
1052 for (size_t i = 2, e = operands.size(); i < e; ++i) {
1053 auto ty = getType(id: operands[i]);
1054 if (!ty) {
1055 return emitError(loc: unknownLoc, message: "unknown argument type in OpTypeFunction");
1056 }
1057 argTypes.push_back(Elt: ty);
1058 }
1059 ArrayRef<Type> returnTypes;
1060 if (!isVoidType(type: returnType)) {
1061 returnTypes = llvm::ArrayRef(returnType);
1062 }
1063 typeMap[operands[0]] = FunctionType::get(context, inputs: argTypes, results: returnTypes);
1064 return success();
1065}
1066
1067LogicalResult spirv::Deserializer::processCooperativeMatrixTypeKHR(
1068 ArrayRef<uint32_t> operands) {
1069 if (operands.size() != 6) {
1070 return emitError(loc: unknownLoc,
1071 message: "OpTypeCooperativeMatrixKHR must have element type, "
1072 "scope, row and column parameters, and use");
1073 }
1074
1075 Type elementTy = getType(id: operands[1]);
1076 if (!elementTy) {
1077 return emitError(loc: unknownLoc,
1078 message: "OpTypeCooperativeMatrixKHR references undefined <id> ")
1079 << operands[1];
1080 }
1081
1082 std::optional<spirv::Scope> scope =
1083 spirv::symbolizeScope(getConstantInt(id: operands[2]).getInt());
1084 if (!scope) {
1085 return emitError(
1086 loc: unknownLoc,
1087 message: "OpTypeCooperativeMatrixKHR references undefined scope <id> ")
1088 << operands[2];
1089 }
1090
1091 IntegerAttr rowsAttr = getConstantInt(id: operands[3]);
1092 IntegerAttr columnsAttr = getConstantInt(id: operands[4]);
1093 IntegerAttr useAttr = getConstantInt(id: operands[5]);
1094
1095 if (!rowsAttr)
1096 return emitError(loc: unknownLoc, message: "OpTypeCooperativeMatrixKHR `Rows` references "
1097 "undefined constant <id> ")
1098 << operands[3];
1099
1100 if (!columnsAttr)
1101 return emitError(loc: unknownLoc, message: "OpTypeCooperativeMatrixKHR `Columns` "
1102 "references undefined constant <id> ")
1103 << operands[4];
1104
1105 if (!useAttr)
1106 return emitError(loc: unknownLoc, message: "OpTypeCooperativeMatrixKHR `Use` references "
1107 "undefined constant <id> ")
1108 << operands[5];
1109
1110 unsigned rows = rowsAttr.getInt();
1111 unsigned columns = columnsAttr.getInt();
1112
1113 std::optional<spirv::CooperativeMatrixUseKHR> use =
1114 spirv::symbolizeCooperativeMatrixUseKHR(useAttr.getInt());
1115 if (!use) {
1116 return emitError(
1117 loc: unknownLoc,
1118 message: "OpTypeCooperativeMatrixKHR references undefined use <id> ")
1119 << operands[5];
1120 }
1121
1122 typeMap[operands[0]] =
1123 spirv::CooperativeMatrixType::get(elementType: elementTy, rows, columns, scope: *scope, use: *use);
1124 return success();
1125}
1126
1127LogicalResult
1128spirv::Deserializer::processRuntimeArrayType(ArrayRef<uint32_t> operands) {
1129 if (operands.size() != 2) {
1130 return emitError(loc: unknownLoc, message: "OpTypeRuntimeArray must have two operands");
1131 }
1132 Type memberType = getType(id: operands[1]);
1133 if (!memberType) {
1134 return emitError(loc: unknownLoc,
1135 message: "OpTypeRuntimeArray references undefined <id> ")
1136 << operands[1];
1137 }
1138 typeMap[operands[0]] = spirv::RuntimeArrayType::get(
1139 elementType: memberType, stride: typeDecorations.lookup(Val: operands[0]));
1140 return success();
1141}
1142
1143LogicalResult
1144spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) {
1145 // TODO: Find a way to handle identified structs when debug info is stripped.
1146
1147 if (operands.empty()) {
1148 return emitError(loc: unknownLoc, message: "OpTypeStruct must have at least result <id>");
1149 }
1150
1151 if (operands.size() == 1) {
1152 // Handle empty struct.
1153 typeMap[operands[0]] =
1154 spirv::StructType::getEmpty(context, identifier: nameMap.lookup(Val: operands[0]).str());
1155 return success();
1156 }
1157
1158 // First element is operand ID, second element is member index in the struct.
1159 SmallVector<std::pair<uint32_t, unsigned>, 0> unresolvedMemberTypes;
1160 SmallVector<Type, 4> memberTypes;
1161
1162 for (auto op : llvm::drop_begin(RangeOrContainer&: operands, N: 1)) {
1163 Type memberType = getType(id: op);
1164 bool typeForwardPtr = (typeForwardPointerIDs.count(key: op) != 0);
1165
1166 if (!memberType && !typeForwardPtr)
1167 return emitError(loc: unknownLoc, message: "OpTypeStruct references undefined <id> ")
1168 << op;
1169
1170 if (!memberType)
1171 unresolvedMemberTypes.emplace_back(Args&: op, Args: memberTypes.size());
1172
1173 memberTypes.push_back(Elt: memberType);
1174 }
1175
1176 SmallVector<spirv::StructType::OffsetInfo, 0> offsetInfo;
1177 SmallVector<spirv::StructType::MemberDecorationInfo, 0> memberDecorationsInfo;
1178 if (memberDecorationMap.count(Val: operands[0])) {
1179 auto &allMemberDecorations = memberDecorationMap[operands[0]];
1180 for (auto memberIndex : llvm::seq<uint32_t>(Begin: 0, End: memberTypes.size())) {
1181 if (allMemberDecorations.count(Val: memberIndex)) {
1182 for (auto &memberDecoration : allMemberDecorations[memberIndex]) {
1183 // Check for offset.
1184 if (memberDecoration.first == spirv::Decoration::Offset) {
1185 // If offset info is empty, resize to the number of members;
1186 if (offsetInfo.empty()) {
1187 offsetInfo.resize(N: memberTypes.size());
1188 }
1189 offsetInfo[memberIndex] = memberDecoration.second[0];
1190 } else {
1191 if (!memberDecoration.second.empty()) {
1192 memberDecorationsInfo.emplace_back(Args&: memberIndex, /*hasValue=*/Args: 1,
1193 Args&: memberDecoration.first,
1194 Args: memberDecoration.second[0]);
1195 } else {
1196 memberDecorationsInfo.emplace_back(Args&: memberIndex, /*hasValue=*/Args: 0,
1197 Args&: memberDecoration.first, Args: 0);
1198 }
1199 }
1200 }
1201 }
1202 }
1203 }
1204
1205 uint32_t structID = operands[0];
1206 std::string structIdentifier = nameMap.lookup(Val: structID).str();
1207
1208 if (structIdentifier.empty()) {
1209 assert(unresolvedMemberTypes.empty() &&
1210 "didn't expect unresolved member types");
1211 typeMap[structID] =
1212 spirv::StructType::get(memberTypes, offsetInfo, memberDecorations: memberDecorationsInfo);
1213 } else {
1214 auto structTy = spirv::StructType::getIdentified(context, identifier: structIdentifier);
1215 typeMap[structID] = structTy;
1216
1217 if (!unresolvedMemberTypes.empty())
1218 deferredStructTypesInfos.push_back(Elt: {.deferredStructType: structTy, .unresolvedMemberTypes: unresolvedMemberTypes,
1219 .memberTypes: memberTypes, .offsetInfo: offsetInfo,
1220 .memberDecorationsInfo: memberDecorationsInfo});
1221 else if (failed(Result: structTy.trySetBody(memberTypes, offsetInfo,
1222 memberDecorations: memberDecorationsInfo)))
1223 return failure();
1224 }
1225
1226 // TODO: Update StructType to have member name as attribute as
1227 // well.
1228 return success();
1229}
1230
1231LogicalResult
1232spirv::Deserializer::processMatrixType(ArrayRef<uint32_t> operands) {
1233 if (operands.size() != 3) {
1234 // Three operands are needed: result_id, column_type, and column_count
1235 return emitError(loc: unknownLoc, message: "OpTypeMatrix must have 3 operands"
1236 " (result_id, column_type, and column_count)");
1237 }
1238 // Matrix columns must be of vector type
1239 Type elementTy = getType(id: operands[1]);
1240 if (!elementTy) {
1241 return emitError(loc: unknownLoc,
1242 message: "OpTypeMatrix references undefined column type.")
1243 << operands[1];
1244 }
1245
1246 uint32_t colsCount = operands[2];
1247 typeMap[operands[0]] = spirv::MatrixType::get(columnType: elementTy, columnCount: colsCount);
1248 return success();
1249}
1250
1251LogicalResult
1252spirv::Deserializer::processTensorARMType(ArrayRef<uint32_t> operands) {
1253 unsigned size = operands.size();
1254 if (size < 2 || size > 4)
1255 return emitError(loc: unknownLoc, message: "OpTypeTensorARM must have 2-4 operands "
1256 "(result_id, element_type, (rank), (shape)) ")
1257 << size;
1258
1259 Type elementTy = getType(id: operands[1]);
1260 if (!elementTy)
1261 return emitError(loc: unknownLoc,
1262 message: "OpTypeTensorARM references undefined element type ")
1263 << operands[1];
1264
1265 if (size == 2) {
1266 typeMap[operands[0]] = TensorArmType::get(shape: {}, elementType: elementTy);
1267 return success();
1268 }
1269
1270 IntegerAttr rankAttr = getConstantInt(id: operands[2]);
1271 if (!rankAttr)
1272 return emitError(loc: unknownLoc, message: "OpTypeTensorARM rank must come from a "
1273 "scalar integer constant instruction");
1274 unsigned rank = rankAttr.getValue().getZExtValue();
1275 if (size == 3) {
1276 SmallVector<int64_t, 4> shape(rank, ShapedType::kDynamic);
1277 typeMap[operands[0]] = TensorArmType::get(shape, elementType: elementTy);
1278 return success();
1279 }
1280
1281 std::optional<std::pair<Attribute, Type>> shapeInfo =
1282 getConstant(id: operands[3]);
1283 if (!shapeInfo)
1284 return emitError(loc: unknownLoc, message: "OpTypeTensorARM shape must come from a "
1285 "constant instruction of type OpTypeArray");
1286
1287 ArrayAttr shapeArrayAttr = llvm::dyn_cast<ArrayAttr>(Val&: shapeInfo->first);
1288 SmallVector<int64_t, 1> shape;
1289 for (auto dimAttr : shapeArrayAttr.getValue()) {
1290 auto dimIntAttr = llvm::dyn_cast<IntegerAttr>(Val&: dimAttr);
1291 if (!dimIntAttr)
1292 return emitError(loc: unknownLoc, message: "OpTypeTensorARM shape has an invalid "
1293 "dimension size");
1294 shape.push_back(Elt: dimIntAttr.getValue().getSExtValue());
1295 }
1296 typeMap[operands[0]] = TensorArmType::get(shape, elementType: elementTy);
1297 return success();
1298}
1299
1300LogicalResult
1301spirv::Deserializer::processTypeForwardPointer(ArrayRef<uint32_t> operands) {
1302 if (operands.size() != 2)
1303 return emitError(loc: unknownLoc,
1304 message: "OpTypeForwardPointer instruction must have two operands");
1305
1306 typeForwardPointerIDs.insert(X: operands[0]);
1307 // TODO: Use the 2nd operand (Storage Class) to validate the OpTypePointer
1308 // instruction that defines the actual type.
1309
1310 return success();
1311}
1312
1313LogicalResult
1314spirv::Deserializer::processImageType(ArrayRef<uint32_t> operands) {
1315 // TODO: Add support for Access Qualifier.
1316 if (operands.size() != 8)
1317 return emitError(
1318 loc: unknownLoc,
1319 message: "OpTypeImage with non-eight operands are not supported yet");
1320
1321 Type elementTy = getType(id: operands[1]);
1322 if (!elementTy)
1323 return emitError(loc: unknownLoc, message: "OpTypeImage references undefined <id>: ")
1324 << operands[1];
1325
1326 auto dim = spirv::symbolizeDim(operands[2]);
1327 if (!dim)
1328 return emitError(loc: unknownLoc, message: "unknown Dim for OpTypeImage: ")
1329 << operands[2];
1330
1331 auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]);
1332 if (!depthInfo)
1333 return emitError(loc: unknownLoc, message: "unknown Depth for OpTypeImage: ")
1334 << operands[3];
1335
1336 auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]);
1337 if (!arrayedInfo)
1338 return emitError(loc: unknownLoc, message: "unknown Arrayed for OpTypeImage: ")
1339 << operands[4];
1340
1341 auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]);
1342 if (!samplingInfo)
1343 return emitError(loc: unknownLoc, message: "unknown MS for OpTypeImage: ") << operands[5];
1344
1345 auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]);
1346 if (!samplerUseInfo)
1347 return emitError(loc: unknownLoc, message: "unknown Sampled for OpTypeImage: ")
1348 << operands[6];
1349
1350 auto format = spirv::symbolizeImageFormat(operands[7]);
1351 if (!format)
1352 return emitError(loc: unknownLoc, message: "unknown Format for OpTypeImage: ")
1353 << operands[7];
1354
1355 typeMap[operands[0]] = spirv::ImageType::get(
1356 elementType: elementTy, dim: dim.value(), depth: depthInfo.value(), arrayed: arrayedInfo.value(),
1357 samplingInfo: samplingInfo.value(), samplerUse: samplerUseInfo.value(), format: format.value());
1358 return success();
1359}
1360
1361LogicalResult
1362spirv::Deserializer::processSampledImageType(ArrayRef<uint32_t> operands) {
1363 if (operands.size() != 2)
1364 return emitError(loc: unknownLoc, message: "OpTypeSampledImage must have two operands");
1365
1366 Type elementTy = getType(id: operands[1]);
1367 if (!elementTy)
1368 return emitError(loc: unknownLoc,
1369 message: "OpTypeSampledImage references undefined <id>: ")
1370 << operands[1];
1371
1372 typeMap[operands[0]] = spirv::SampledImageType::get(imageType: elementTy);
1373 return success();
1374}
1375
1376//===----------------------------------------------------------------------===//
1377// Constant
1378//===----------------------------------------------------------------------===//
1379
1380LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands,
1381 bool isSpec) {
1382 StringRef opname = isSpec ? "OpSpecConstant" : "OpConstant";
1383
1384 if (operands.size() < 2) {
1385 return emitError(loc: unknownLoc)
1386 << opname << " must have type <id> and result <id>";
1387 }
1388 if (operands.size() < 3) {
1389 return emitError(loc: unknownLoc)
1390 << opname << " must have at least 1 more parameter";
1391 }
1392
1393 Type resultType = getType(id: operands[0]);
1394 if (!resultType) {
1395 return emitError(loc: unknownLoc, message: "undefined result type from <id> ")
1396 << operands[0];
1397 }
1398
1399 auto checkOperandSizeForBitwidth = [&](unsigned bitwidth) -> LogicalResult {
1400 if (bitwidth == 64) {
1401 if (operands.size() == 4) {
1402 return success();
1403 }
1404 return emitError(loc: unknownLoc)
1405 << opname << " should have 2 parameters for 64-bit values";
1406 }
1407 if (bitwidth <= 32) {
1408 if (operands.size() == 3) {
1409 return success();
1410 }
1411
1412 return emitError(loc: unknownLoc)
1413 << opname
1414 << " should have 1 parameter for values with no more than 32 bits";
1415 }
1416 return emitError(loc: unknownLoc, message: "unsupported OpConstant bitwidth: ")
1417 << bitwidth;
1418 };
1419
1420 auto resultID = operands[1];
1421
1422 if (auto intType = dyn_cast<IntegerType>(Val&: resultType)) {
1423 auto bitwidth = intType.getWidth();
1424 if (failed(Result: checkOperandSizeForBitwidth(bitwidth))) {
1425 return failure();
1426 }
1427
1428 APInt value;
1429 if (bitwidth == 64) {
1430 // 64-bit integers are represented with two SPIR-V words. According to
1431 // SPIR-V spec: "When the type’s bit width is larger than one word, the
1432 // literal’s low-order words appear first."
1433 struct DoubleWord {
1434 uint32_t word1;
1435 uint32_t word2;
1436 } words = {.word1: operands[2], .word2: operands[3]};
1437 value = APInt(64, llvm::bit_cast<uint64_t>(from: words), /*isSigned=*/true);
1438 } else if (bitwidth <= 32) {
1439 value = APInt(bitwidth, operands[2], /*isSigned=*/true,
1440 /*implicitTrunc=*/true);
1441 }
1442
1443 auto attr = opBuilder.getIntegerAttr(type: intType, value);
1444
1445 if (isSpec) {
1446 createSpecConstant(loc: unknownLoc, resultID, defaultValue: attr);
1447 } else {
1448 // For normal constants, we just record the attribute (and its type) for
1449 // later materialization at use sites.
1450 constantMap.try_emplace(Key: resultID, Args&: attr, Args&: intType);
1451 }
1452
1453 return success();
1454 }
1455
1456 if (auto floatType = dyn_cast<FloatType>(Val&: resultType)) {
1457 auto bitwidth = floatType.getWidth();
1458 if (failed(Result: checkOperandSizeForBitwidth(bitwidth))) {
1459 return failure();
1460 }
1461
1462 APFloat value(0.f);
1463 if (floatType.isF64()) {
1464 // Double values are represented with two SPIR-V words. According to
1465 // SPIR-V spec: "When the type’s bit width is larger than one word, the
1466 // literal’s low-order words appear first."
1467 struct DoubleWord {
1468 uint32_t word1;
1469 uint32_t word2;
1470 } words = {.word1: operands[2], .word2: operands[3]};
1471 value = APFloat(llvm::bit_cast<double>(from: words));
1472 } else if (floatType.isF32()) {
1473 value = APFloat(llvm::bit_cast<float>(from: operands[2]));
1474 } else if (floatType.isF16()) {
1475 APInt data(16, operands[2]);
1476 value = APFloat(APFloat::IEEEhalf(), data);
1477 } else if (floatType.isBF16()) {
1478 APInt data(16, operands[2]);
1479 value = APFloat(APFloat::BFloat(), data);
1480 }
1481
1482 auto attr = opBuilder.getFloatAttr(type: floatType, value);
1483 if (isSpec) {
1484 createSpecConstant(loc: unknownLoc, resultID, defaultValue: attr);
1485 } else {
1486 // For normal constants, we just record the attribute (and its type) for
1487 // later materialization at use sites.
1488 constantMap.try_emplace(Key: resultID, Args&: attr, Args&: floatType);
1489 }
1490
1491 return success();
1492 }
1493
1494 return emitError(loc: unknownLoc, message: "OpConstant can only generate values of "
1495 "scalar integer or floating-point type");
1496}
1497
1498LogicalResult spirv::Deserializer::processConstantBool(
1499 bool isTrue, ArrayRef<uint32_t> operands, bool isSpec) {
1500 if (operands.size() != 2) {
1501 return emitError(loc: unknownLoc, message: "Op")
1502 << (isSpec ? "Spec" : "") << "Constant"
1503 << (isTrue ? "True" : "False")
1504 << " must have type <id> and result <id>";
1505 }
1506
1507 auto attr = opBuilder.getBoolAttr(value: isTrue);
1508 auto resultID = operands[1];
1509 if (isSpec) {
1510 createSpecConstant(loc: unknownLoc, resultID, defaultValue: attr);
1511 } else {
1512 // For normal constants, we just record the attribute (and its type) for
1513 // later materialization at use sites.
1514 constantMap.try_emplace(Key: resultID, Args&: attr, Args: opBuilder.getI1Type());
1515 }
1516
1517 return success();
1518}
1519
1520LogicalResult
1521spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
1522 if (operands.size() < 2) {
1523 return emitError(loc: unknownLoc,
1524 message: "OpConstantComposite must have type <id> and result <id>");
1525 }
1526 if (operands.size() < 3) {
1527 return emitError(loc: unknownLoc,
1528 message: "OpConstantComposite must have at least 1 parameter");
1529 }
1530
1531 Type resultType = getType(id: operands[0]);
1532 if (!resultType) {
1533 return emitError(loc: unknownLoc, message: "undefined result type from <id> ")
1534 << operands[0];
1535 }
1536
1537 SmallVector<Attribute, 4> elements;
1538 elements.reserve(N: operands.size() - 2);
1539 for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1540 auto elementInfo = getConstant(id: operands[i]);
1541 if (!elementInfo) {
1542 return emitError(loc: unknownLoc, message: "OpConstantComposite component <id> ")
1543 << operands[i] << " must come from a normal constant";
1544 }
1545 elements.push_back(Elt: elementInfo->first);
1546 }
1547
1548 auto resultID = operands[1];
1549 if (auto shapedType = dyn_cast<ShapedType>(Val&: resultType)) {
1550 auto attr = DenseElementsAttr::get(type: shapedType, values: elements);
1551 // For normal constants, we just record the attribute (and its type) for
1552 // later materialization at use sites.
1553 constantMap.try_emplace(Key: resultID, Args&: attr, Args&: shapedType);
1554 } else if (auto arrayType = dyn_cast<spirv::ArrayType>(Val&: resultType)) {
1555 auto attr = opBuilder.getArrayAttr(value: elements);
1556 constantMap.try_emplace(Key: resultID, Args&: attr, Args&: resultType);
1557 } else {
1558 return emitError(loc: unknownLoc, message: "unsupported OpConstantComposite type: ")
1559 << resultType;
1560 }
1561
1562 return success();
1563}
1564
1565LogicalResult spirv::Deserializer::processConstantCompositeReplicateEXT(
1566 ArrayRef<uint32_t> operands) {
1567 if (operands.size() != 3) {
1568 return emitError(
1569 loc: unknownLoc,
1570 message: "OpConstantCompositeReplicateEXT expects 3 operands but found ")
1571 << operands.size();
1572 }
1573
1574 Type resultType = getType(id: operands[0]);
1575 if (!resultType) {
1576 return emitError(loc: unknownLoc, message: "undefined result type from <id> ")
1577 << operands[0];
1578 }
1579
1580 auto compositeType = dyn_cast<CompositeType>(Val&: resultType);
1581 if (!compositeType) {
1582 return emitError(loc: unknownLoc,
1583 message: "result type from <id> is not a composite type")
1584 << operands[0];
1585 }
1586
1587 uint32_t resultID = operands[1];
1588 uint32_t constantID = operands[2];
1589
1590 std::optional<std::pair<Attribute, Type>> constantInfo =
1591 getConstant(id: constantID);
1592 if (constantInfo.has_value()) {
1593 constantCompositeReplicateMap.try_emplace(
1594 Key: resultID, Args&: constantInfo.value().first, Args&: resultType);
1595 return success();
1596 }
1597
1598 std::optional<std::pair<Attribute, Type>> replicatedConstantCompositeInfo =
1599 getConstantCompositeReplicate(id: constantID);
1600 if (replicatedConstantCompositeInfo.has_value()) {
1601 constantCompositeReplicateMap.try_emplace(
1602 Key: resultID, Args&: replicatedConstantCompositeInfo.value().first, Args&: resultType);
1603 return success();
1604 }
1605
1606 return emitError(loc: unknownLoc, message: "OpConstantCompositeReplicateEXT operand <id> ")
1607 << constantID
1608 << " must come from a normal constant or a "
1609 "OpConstantCompositeReplicateEXT";
1610}
1611
1612LogicalResult
1613spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
1614 if (operands.size() < 2) {
1615 return emitError(
1616 loc: unknownLoc,
1617 message: "OpSpecConstantComposite must have type <id> and result <id>");
1618 }
1619 if (operands.size() < 3) {
1620 return emitError(loc: unknownLoc,
1621 message: "OpSpecConstantComposite must have at least 1 parameter");
1622 }
1623
1624 Type resultType = getType(id: operands[0]);
1625 if (!resultType) {
1626 return emitError(loc: unknownLoc, message: "undefined result type from <id> ")
1627 << operands[0];
1628 }
1629
1630 auto resultID = operands[1];
1631 auto symName = opBuilder.getStringAttr(bytes: getSpecConstantSymbol(id: resultID));
1632
1633 SmallVector<Attribute, 4> elements;
1634 elements.reserve(N: operands.size() - 2);
1635 for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1636 auto elementInfo = getSpecConstant(id: operands[i]);
1637 elements.push_back(Elt: SymbolRefAttr::get(symbol: elementInfo));
1638 }
1639
1640 auto op = opBuilder.create<spirv::SpecConstantCompositeOp>(
1641 location: unknownLoc, args: TypeAttr::get(type: resultType), args&: symName,
1642 args: opBuilder.getArrayAttr(value: elements));
1643 specConstCompositeMap[resultID] = op;
1644
1645 return success();
1646}
1647
1648LogicalResult spirv::Deserializer::processSpecConstantCompositeReplicateEXT(
1649 ArrayRef<uint32_t> operands) {
1650 if (operands.size() != 3) {
1651 return emitError(loc: unknownLoc, message: "OpSpecConstantCompositeReplicateEXT expects "
1652 "3 operands but found ")
1653 << operands.size();
1654 }
1655
1656 Type resultType = getType(id: operands[0]);
1657 if (!resultType) {
1658 return emitError(loc: unknownLoc, message: "undefined result type from <id> ")
1659 << operands[0];
1660 }
1661
1662 auto compositeType = dyn_cast<CompositeType>(Val&: resultType);
1663 if (!compositeType) {
1664 return emitError(loc: unknownLoc,
1665 message: "result type from <id> is not a composite type")
1666 << operands[0];
1667 }
1668
1669 uint32_t resultID = operands[1];
1670
1671 auto symName = opBuilder.getStringAttr(bytes: getSpecConstantSymbol(id: resultID));
1672 spirv::SpecConstantOp constituentSpecConstantOp =
1673 getSpecConstant(id: operands[2]);
1674 auto op = opBuilder.create<spirv::EXTSpecConstantCompositeReplicateOp>(
1675 location: unknownLoc, args: TypeAttr::get(type: resultType), args&: symName,
1676 args: SymbolRefAttr::get(symbol: constituentSpecConstantOp));
1677
1678 specConstCompositeReplicateMap[resultID] = op;
1679
1680 return success();
1681}
1682
1683LogicalResult
1684spirv::Deserializer::processSpecConstantOperation(ArrayRef<uint32_t> operands) {
1685 if (operands.size() < 3)
1686 return emitError(loc: unknownLoc, message: "OpConstantOperation must have type <id>, "
1687 "result <id>, and operand opcode");
1688
1689 uint32_t resultTypeID = operands[0];
1690
1691 if (!getType(id: resultTypeID))
1692 return emitError(loc: unknownLoc, message: "undefined result type from <id> ")
1693 << resultTypeID;
1694
1695 uint32_t resultID = operands[1];
1696 spirv::Opcode enclosedOpcode = static_cast<spirv::Opcode>(operands[2]);
1697 auto emplaceResult = specConstOperationMap.try_emplace(
1698 Key: resultID,
1699 Args: SpecConstOperationMaterializationInfo{
1700 .enclodesOpcode: enclosedOpcode, .resultTypeID: resultTypeID,
1701 .enclosedOpOperands: SmallVector<uint32_t>{operands.begin() + 3, operands.end()}});
1702
1703 if (!emplaceResult.second)
1704 return emitError(loc: unknownLoc, message: "value with <id>: ")
1705 << resultID << " is probably defined before.";
1706
1707 return success();
1708}
1709
1710Value spirv::Deserializer::materializeSpecConstantOperation(
1711 uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID,
1712 ArrayRef<uint32_t> enclosedOpOperands) {
1713
1714 Type resultType = getType(id: resultTypeID);
1715
1716 // Instructions wrapped by OpSpecConstantOp need an ID for their
1717 // Deserializer::processOp<op_name>(...) to emit the corresponding SPIR-V
1718 // dialect wrapped op. For that purpose, a new value map is created and "fake"
1719 // ID in that map is assigned to the result of the enclosed instruction. Note
1720 // that there is no need to update this fake ID since we only need to
1721 // reference the created Value for the enclosed op from the spv::YieldOp
1722 // created later in this method (both of which are the only values in their
1723 // region: the SpecConstantOperation's region). If we encounter another
1724 // SpecConstantOperation in the module, we simply re-use the fake ID since the
1725 // previous Value assigned to it isn't visible in the current scope anyway.
1726 DenseMap<uint32_t, Value> newValueMap;
1727 llvm::SaveAndRestore valueMapGuard(valueMap, newValueMap);
1728 constexpr uint32_t fakeID = static_cast<uint32_t>(-3);
1729
1730 SmallVector<uint32_t, 4> enclosedOpResultTypeAndOperands;
1731 enclosedOpResultTypeAndOperands.push_back(Elt: resultTypeID);
1732 enclosedOpResultTypeAndOperands.push_back(Elt: fakeID);
1733 enclosedOpResultTypeAndOperands.append(in_start: enclosedOpOperands.begin(),
1734 in_end: enclosedOpOperands.end());
1735
1736 // Process enclosed instruction before creating the enclosing
1737 // specConstantOperation (and its region). This way, references to constants,
1738 // global variables, and spec constants will be materialized outside the new
1739 // op's region. For more info, see Deserializer::getValue's implementation.
1740 if (failed(
1741 Result: processInstruction(opcode: enclosedOpcode, operands: enclosedOpResultTypeAndOperands)))
1742 return Value();
1743
1744 // Since the enclosed op is emitted in the current block, split it in a
1745 // separate new block.
1746 Block *enclosedBlock = curBlock->splitBlock(splitBeforeOp: &curBlock->back());
1747
1748 auto loc = createFileLineColLoc(opBuilder);
1749 auto specConstOperationOp =
1750 opBuilder.create<spirv::SpecConstantOperationOp>(location: loc, args&: resultType);
1751
1752 Region &body = specConstOperationOp.getBody();
1753 // Move the new block into SpecConstantOperation's body.
1754 body.getBlocks().splice(where: body.end(), L2&: curBlock->getParent()->getBlocks(),
1755 first: Region::iterator(enclosedBlock));
1756 Block &block = body.back();
1757
1758 // RAII guard to reset the insertion point to the module's region after
1759 // deserializing the body of the specConstantOperation.
1760 OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
1761 opBuilder.setInsertionPointToEnd(&block);
1762
1763 opBuilder.create<spirv::YieldOp>(location: loc, args: block.front().getResult(idx: 0));
1764 return specConstOperationOp.getResult();
1765}
1766
1767LogicalResult
1768spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
1769 if (operands.size() != 2) {
1770 return emitError(loc: unknownLoc,
1771 message: "OpConstantNull must have type <id> and result <id>");
1772 }
1773
1774 Type resultType = getType(id: operands[0]);
1775 if (!resultType) {
1776 return emitError(loc: unknownLoc, message: "undefined result type from <id> ")
1777 << operands[0];
1778 }
1779
1780 auto resultID = operands[1];
1781 if (resultType.isIntOrFloat() || isa<VectorType>(Val: resultType)) {
1782 auto attr = opBuilder.getZeroAttr(type: resultType);
1783 // For normal constants, we just record the attribute (and its type) for
1784 // later materialization at use sites.
1785 constantMap.try_emplace(Key: resultID, Args&: attr, Args&: resultType);
1786 return success();
1787 }
1788
1789 return emitError(loc: unknownLoc, message: "unsupported OpConstantNull type: ")
1790 << resultType;
1791}
1792
1793//===----------------------------------------------------------------------===//
1794// Control flow
1795//===----------------------------------------------------------------------===//
1796
1797Block *spirv::Deserializer::getOrCreateBlock(uint32_t id) {
1798 if (auto *block = getBlock(id)) {
1799 LLVM_DEBUG(logger.startLine() << "[block] got exiting block for id = " << id
1800 << " @ " << block << "\n");
1801 return block;
1802 }
1803
1804 // We don't know where this block will be placed finally (in a
1805 // spirv.mlir.selection or spirv.mlir.loop or function). Create it into the
1806 // function for now and sort out the proper place later.
1807 auto *block = curFunction->addBlock();
1808 LLVM_DEBUG(logger.startLine() << "[block] created block for id = " << id
1809 << " @ " << block << "\n");
1810 return blockMap[id] = block;
1811}
1812
1813LogicalResult spirv::Deserializer::processBranch(ArrayRef<uint32_t> operands) {
1814 if (!curBlock) {
1815 return emitError(loc: unknownLoc, message: "OpBranch must appear inside a block");
1816 }
1817
1818 if (operands.size() != 1) {
1819 return emitError(loc: unknownLoc, message: "OpBranch must take exactly one target label");
1820 }
1821
1822 auto *target = getOrCreateBlock(id: operands[0]);
1823 auto loc = createFileLineColLoc(opBuilder);
1824 // The preceding instruction for the OpBranch instruction could be an
1825 // OpLoopMerge or an OpSelectionMerge instruction, in this case they will have
1826 // the same OpLine information.
1827 opBuilder.create<spirv::BranchOp>(location: loc, args&: target);
1828
1829 clearDebugLine();
1830 return success();
1831}
1832
1833LogicalResult
1834spirv::Deserializer::processBranchConditional(ArrayRef<uint32_t> operands) {
1835 if (!curBlock) {
1836 return emitError(loc: unknownLoc,
1837 message: "OpBranchConditional must appear inside a block");
1838 }
1839
1840 if (operands.size() != 3 && operands.size() != 5) {
1841 return emitError(loc: unknownLoc,
1842 message: "OpBranchConditional must have condition, true label, "
1843 "false label, and optionally two branch weights");
1844 }
1845
1846 auto condition = getValue(id: operands[0]);
1847 auto *trueBlock = getOrCreateBlock(id: operands[1]);
1848 auto *falseBlock = getOrCreateBlock(id: operands[2]);
1849
1850 std::optional<std::pair<uint32_t, uint32_t>> weights;
1851 if (operands.size() == 5) {
1852 weights = std::make_pair(x: operands[3], y: operands[4]);
1853 }
1854 // The preceding instruction for the OpBranchConditional instruction could be
1855 // an OpSelectionMerge instruction, in this case they will have the same
1856 // OpLine information.
1857 auto loc = createFileLineColLoc(opBuilder);
1858 opBuilder.create<spirv::BranchConditionalOp>(
1859 location: loc, args&: condition, args&: trueBlock,
1860 /*trueArguments=*/args: ArrayRef<Value>(), args&: falseBlock,
1861 /*falseArguments=*/args: ArrayRef<Value>(), args&: weights);
1862
1863 clearDebugLine();
1864 return success();
1865}
1866
1867LogicalResult spirv::Deserializer::processLabel(ArrayRef<uint32_t> operands) {
1868 if (!curFunction) {
1869 return emitError(loc: unknownLoc, message: "OpLabel must appear inside a function");
1870 }
1871
1872 if (operands.size() != 1) {
1873 return emitError(loc: unknownLoc, message: "OpLabel should only have result <id>");
1874 }
1875
1876 auto labelID = operands[0];
1877 // We may have forward declared this block.
1878 auto *block = getOrCreateBlock(id: labelID);
1879 LLVM_DEBUG(logger.startLine()
1880 << "[block] populating block " << block << "\n");
1881 // If we have seen this block, make sure it was just a forward declaration.
1882 assert(block->empty() && "re-deserialize the same block!");
1883
1884 opBuilder.setInsertionPointToStart(block);
1885 blockMap[labelID] = curBlock = block;
1886
1887 return success();
1888}
1889
1890LogicalResult
1891spirv::Deserializer::processSelectionMerge(ArrayRef<uint32_t> operands) {
1892 if (!curBlock) {
1893 return emitError(loc: unknownLoc, message: "OpSelectionMerge must appear in a block");
1894 }
1895
1896 if (operands.size() < 2) {
1897 return emitError(
1898 loc: unknownLoc,
1899 message: "OpSelectionMerge must specify merge target and selection control");
1900 }
1901
1902 auto *mergeBlock = getOrCreateBlock(id: operands[0]);
1903 auto loc = createFileLineColLoc(opBuilder);
1904 auto selectionControl = operands[1];
1905
1906 if (!blockMergeInfo.try_emplace(Key: curBlock, Args&: loc, Args&: selectionControl, Args&: mergeBlock)
1907 .second) {
1908 return emitError(
1909 loc: unknownLoc,
1910 message: "a block cannot have more than one OpSelectionMerge instruction");
1911 }
1912
1913 return success();
1914}
1915
1916LogicalResult
1917spirv::Deserializer::processLoopMerge(ArrayRef<uint32_t> operands) {
1918 if (!curBlock) {
1919 return emitError(loc: unknownLoc, message: "OpLoopMerge must appear in a block");
1920 }
1921
1922 if (operands.size() < 3) {
1923 return emitError(loc: unknownLoc, message: "OpLoopMerge must specify merge target, "
1924 "continue target and loop control");
1925 }
1926
1927 auto *mergeBlock = getOrCreateBlock(id: operands[0]);
1928 auto *continueBlock = getOrCreateBlock(id: operands[1]);
1929 auto loc = createFileLineColLoc(opBuilder);
1930 uint32_t loopControl = operands[2];
1931
1932 if (!blockMergeInfo
1933 .try_emplace(Key: curBlock, Args&: loc, Args&: loopControl, Args&: mergeBlock, Args&: continueBlock)
1934 .second) {
1935 return emitError(
1936 loc: unknownLoc,
1937 message: "a block cannot have more than one OpLoopMerge instruction");
1938 }
1939
1940 return success();
1941}
1942
1943LogicalResult spirv::Deserializer::processPhi(ArrayRef<uint32_t> operands) {
1944 if (!curBlock) {
1945 return emitError(loc: unknownLoc, message: "OpPhi must appear in a block");
1946 }
1947
1948 if (operands.size() < 4) {
1949 return emitError(loc: unknownLoc, message: "OpPhi must specify result type, result <id>, "
1950 "and variable-parent pairs");
1951 }
1952
1953 // Create a block argument for this OpPhi instruction.
1954 Type blockArgType = getType(id: operands[0]);
1955 BlockArgument blockArg = curBlock->addArgument(type: blockArgType, loc: unknownLoc);
1956 valueMap[operands[1]] = blockArg;
1957 LLVM_DEBUG(logger.startLine()
1958 << "[phi] created block argument " << blockArg
1959 << " id = " << operands[1] << " of type " << blockArgType << "\n");
1960
1961 // For each (value, predecessor) pair, insert the value to the predecessor's
1962 // blockPhiInfo entry so later we can fix the block argument there.
1963 for (unsigned i = 2, e = operands.size(); i < e; i += 2) {
1964 uint32_t value = operands[i];
1965 Block *predecessor = getOrCreateBlock(id: operands[i + 1]);
1966 std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock};
1967 blockPhiInfo[predecessorTargetPair].push_back(Elt: value);
1968 LLVM_DEBUG(logger.startLine() << "[phi] predecessor @ " << predecessor
1969 << " with arg id = " << value << "\n");
1970 }
1971
1972 return success();
1973}
1974
1975namespace {
1976/// A class for putting all blocks in a structured selection/loop in a
1977/// spirv.mlir.selection/spirv.mlir.loop op.
1978class ControlFlowStructurizer {
1979public:
1980#ifndef NDEBUG
1981 ControlFlowStructurizer(Location loc, uint32_t control,
1982 spirv::BlockMergeInfoMap &mergeInfo, Block *header,
1983 Block *merge, Block *cont,
1984 llvm::ScopedPrinter &logger)
1985 : location(loc), control(control), blockMergeInfo(mergeInfo),
1986 headerBlock(header), mergeBlock(merge), continueBlock(cont),
1987 logger(logger) {}
1988#else
1989 ControlFlowStructurizer(Location loc, uint32_t control,
1990 spirv::BlockMergeInfoMap &mergeInfo, Block *header,
1991 Block *merge, Block *cont)
1992 : location(loc), control(control), blockMergeInfo(mergeInfo),
1993 headerBlock(header), mergeBlock(merge), continueBlock(cont) {}
1994#endif
1995
1996 /// Structurizes the loop at the given `headerBlock`.
1997 ///
1998 /// This method will create an spirv.mlir.loop op in the `mergeBlock` and move
1999 /// all blocks in the structured loop into the spirv.mlir.loop's region. All
2000 /// branches to the `headerBlock` will be redirected to the `mergeBlock`. This
2001 /// method will also update `mergeInfo` by remapping all blocks inside to the
2002 /// newly cloned ones inside structured control flow op's regions.
2003 LogicalResult structurize();
2004
2005private:
2006 /// Creates a new spirv.mlir.selection op at the beginning of the
2007 /// `mergeBlock`.
2008 spirv::SelectionOp createSelectionOp(uint32_t selectionControl);
2009
2010 /// Creates a new spirv.mlir.loop op at the beginning of the `mergeBlock`.
2011 spirv::LoopOp createLoopOp(uint32_t loopControl);
2012
2013 /// Collects all blocks reachable from `headerBlock` except `mergeBlock`.
2014 void collectBlocksInConstruct();
2015
2016 Location location;
2017 uint32_t control;
2018
2019 spirv::BlockMergeInfoMap &blockMergeInfo;
2020
2021 Block *headerBlock;
2022 Block *mergeBlock;
2023 Block *continueBlock; // nullptr for spirv.mlir.selection
2024
2025 SetVector<Block *> constructBlocks;
2026
2027#ifndef NDEBUG
2028 /// A logger used to emit information during the deserialzation process.
2029 llvm::ScopedPrinter &logger;
2030#endif
2031};
2032} // namespace
2033
2034spirv::SelectionOp
2035ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
2036 // Create a builder and set the insertion point to the beginning of the
2037 // merge block so that the newly created SelectionOp will be inserted there.
2038 OpBuilder builder(&mergeBlock->front());
2039
2040 auto control = static_cast<spirv::SelectionControl>(selectionControl);
2041 auto selectionOp = builder.create<spirv::SelectionOp>(location, args&: control);
2042 selectionOp.addMergeBlock(builder);
2043
2044 return selectionOp;
2045}
2046
2047spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {
2048 // Create a builder and set the insertion point to the beginning of the
2049 // merge block so that the newly created LoopOp will be inserted there.
2050 OpBuilder builder(&mergeBlock->front());
2051
2052 auto control = static_cast<spirv::LoopControl>(loopControl);
2053 auto loopOp = builder.create<spirv::LoopOp>(location, args&: control);
2054 loopOp.addEntryAndMergeBlock(builder);
2055
2056 return loopOp;
2057}
2058
2059void ControlFlowStructurizer::collectBlocksInConstruct() {
2060 assert(constructBlocks.empty() && "expected empty constructBlocks");
2061
2062 // Put the header block in the work list first.
2063 constructBlocks.insert(X: headerBlock);
2064
2065 // For each item in the work list, add its successors excluding the merge
2066 // block.
2067 for (unsigned i = 0; i < constructBlocks.size(); ++i) {
2068 for (auto *successor : constructBlocks[i]->getSuccessors())
2069 if (successor != mergeBlock)
2070 constructBlocks.insert(X: successor);
2071 }
2072}
2073
2074LogicalResult ControlFlowStructurizer::structurize() {
2075 Operation *op = nullptr;
2076 bool isLoop = continueBlock != nullptr;
2077 if (isLoop) {
2078 if (auto loopOp = createLoopOp(loopControl: control))
2079 op = loopOp.getOperation();
2080 } else {
2081 if (auto selectionOp = createSelectionOp(selectionControl: control))
2082 op = selectionOp.getOperation();
2083 }
2084 if (!op)
2085 return failure();
2086 Region &body = op->getRegion(index: 0);
2087
2088 IRMapping mapper;
2089 // All references to the old merge block should be directed to the
2090 // selection/loop merge block in the SelectionOp/LoopOp's region.
2091 mapper.map(from: mergeBlock, to: &body.back());
2092
2093 collectBlocksInConstruct();
2094
2095 // We've identified all blocks belonging to the selection/loop's region. Now
2096 // need to "move" them into the selection/loop. Instead of really moving the
2097 // blocks, in the following we copy them and remap all values and branches.
2098 // This is because:
2099 // * Inserting a block into a region requires the block not in any region
2100 // before. But selections/loops can nest so we can create selection/loop ops
2101 // in a nested manner, which means some blocks may already be in a
2102 // selection/loop region when to be moved again.
2103 // * It's much trickier to fix up the branches into and out of the loop's
2104 // region: we need to treat not-moved blocks and moved blocks differently:
2105 // Not-moved blocks jumping to the loop header block need to jump to the
2106 // merge point containing the new loop op but not the loop continue block's
2107 // back edge. Moved blocks jumping out of the loop need to jump to the
2108 // merge block inside the loop region but not other not-moved blocks.
2109 // We cannot use replaceAllUsesWith clearly and it's harder to follow the
2110 // logic.
2111
2112 // Create a corresponding block in the SelectionOp/LoopOp's region for each
2113 // block in this loop construct.
2114 OpBuilder builder(body);
2115 for (auto *block : constructBlocks) {
2116 // Create a block and insert it before the selection/loop merge block in the
2117 // SelectionOp/LoopOp's region.
2118 auto *newBlock = builder.createBlock(insertBefore: &body.back());
2119 mapper.map(from: block, to: newBlock);
2120 LLVM_DEBUG(logger.startLine() << "[cf] cloned block " << newBlock
2121 << " from block " << block << "\n");
2122 if (!isFnEntryBlock(block)) {
2123 for (BlockArgument blockArg : block->getArguments()) {
2124 auto newArg =
2125 newBlock->addArgument(type: blockArg.getType(), loc: blockArg.getLoc());
2126 mapper.map(from: blockArg, to: newArg);
2127 LLVM_DEBUG(logger.startLine() << "[cf] remapped block argument "
2128 << blockArg << " to " << newArg << "\n");
2129 }
2130 } else {
2131 LLVM_DEBUG(logger.startLine()
2132 << "[cf] block " << block << " is a function entry block\n");
2133 }
2134
2135 for (auto &op : *block)
2136 newBlock->push_back(op: op.clone(mapper));
2137 }
2138
2139 // Go through all ops and remap the operands.
2140 auto remapOperands = [&](Operation *op) {
2141 for (auto &operand : op->getOpOperands())
2142 if (Value mappedOp = mapper.lookupOrNull(from: operand.get()))
2143 operand.set(mappedOp);
2144 for (auto &succOp : op->getBlockOperands())
2145 if (Block *mappedOp = mapper.lookupOrNull(from: succOp.get()))
2146 succOp.set(mappedOp);
2147 };
2148 for (auto &block : body)
2149 block.walk(callback&: remapOperands);
2150
2151 // We have created the SelectionOp/LoopOp and "moved" all blocks belonging to
2152 // the selection/loop construct into its region. Next we need to fix the
2153 // connections between this new SelectionOp/LoopOp with existing blocks.
2154
2155 // All existing incoming branches should go to the merge block, where the
2156 // SelectionOp/LoopOp resides right now.
2157 headerBlock->replaceAllUsesWith(newValue&: mergeBlock);
2158
2159 LLVM_DEBUG({
2160 logger.startLine() << "[cf] after cloning and fixing references:\n";
2161 headerBlock->getParentOp()->print(logger.getOStream());
2162 logger.startLine() << "\n";
2163 });
2164
2165 if (isLoop) {
2166 if (!mergeBlock->args_empty()) {
2167 return mergeBlock->getParentOp()->emitError(
2168 message: "OpPhi in loop merge block unsupported");
2169 }
2170
2171 // The loop header block may have block arguments. Since now we place the
2172 // loop op inside the old merge block, we need to make sure the old merge
2173 // block has the same block argument list.
2174 for (BlockArgument blockArg : headerBlock->getArguments())
2175 mergeBlock->addArgument(type: blockArg.getType(), loc: blockArg.getLoc());
2176
2177 // If the loop header block has block arguments, make sure the spirv.Branch
2178 // op matches.
2179 SmallVector<Value, 4> blockArgs;
2180 if (!headerBlock->args_empty())
2181 blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()};
2182
2183 // The loop entry block should have a unconditional branch jumping to the
2184 // loop header block.
2185 builder.setInsertionPointToEnd(&body.front());
2186 builder.create<spirv::BranchOp>(location, args: mapper.lookupOrNull(from: headerBlock),
2187 args: ArrayRef<Value>(blockArgs));
2188 }
2189
2190 // Values defined inside the selection region that need to be yielded outside
2191 // the region.
2192 SmallVector<Value> valuesToYield;
2193 // Outside uses of values that were sunk into the selection region. Those uses
2194 // will be replaced with values returned by the SelectionOp.
2195 SmallVector<Value> outsideUses;
2196
2197 // Move block arguments of the original block (`mergeBlock`) into the merge
2198 // block inside the selection (`body.back()`). Values produced by block
2199 // arguments will be yielded by the selection region. We do not update uses or
2200 // erase original block arguments yet. It will be done later in the code.
2201 //
2202 // Code below is not executed for loops as it would interfere with the logic
2203 // above. Currently block arguments in the merge block are not supported, but
2204 // instead, the code above copies those arguments from the header block into
2205 // the merge block. As such, running the code would yield those copied
2206 // arguments that is most likely not a desired behaviour. This may need to be
2207 // revisited in the future.
2208 if (!isLoop)
2209 for (BlockArgument blockArg : mergeBlock->getArguments()) {
2210 // Create new block arguments in the last block ("merge block") of the
2211 // selection region. We create one argument for each argument in
2212 // `mergeBlock`. This new value will need to be yielded, and the original
2213 // value replaced, so add them to appropriate vectors.
2214 body.back().addArgument(type: blockArg.getType(), loc: blockArg.getLoc());
2215 valuesToYield.push_back(Elt: body.back().getArguments().back());
2216 outsideUses.push_back(Elt: blockArg);
2217 }
2218
2219 // All the blocks cloned into the SelectionOp/LoopOp's region can now be
2220 // cleaned up.
2221 LLVM_DEBUG(logger.startLine() << "[cf] cleaning up blocks after clone\n");
2222 // First we need to drop all operands' references inside all blocks. This is
2223 // needed because we can have blocks referencing SSA values from one another.
2224 for (auto *block : constructBlocks)
2225 block->dropAllReferences();
2226
2227 // All internal uses should be removed from original blocks by now, so
2228 // whatever is left is an outside use and will need to be yielded from
2229 // the newly created selection / loop region.
2230 for (Block *block : constructBlocks) {
2231 for (Operation &op : *block) {
2232 if (!op.use_empty())
2233 for (Value result : op.getResults()) {
2234 valuesToYield.push_back(Elt: mapper.lookupOrNull(from: result));
2235 outsideUses.push_back(Elt: result);
2236 }
2237 }
2238 for (BlockArgument &arg : block->getArguments()) {
2239 if (!arg.use_empty()) {
2240 valuesToYield.push_back(Elt: mapper.lookupOrNull(from: arg));
2241 outsideUses.push_back(Elt: arg);
2242 }
2243 }
2244 }
2245
2246 assert(valuesToYield.size() == outsideUses.size());
2247
2248 // If we need to yield any values from the selection / loop region we will
2249 // take care of it here.
2250 if (!valuesToYield.empty()) {
2251 LLVM_DEBUG(logger.startLine()
2252 << "[cf] yielding values from the selection / loop region\n");
2253
2254 // Update `mlir.merge` with values to be yield.
2255 auto mergeOps = body.back().getOps<spirv::MergeOp>();
2256 Operation *merge = llvm::getSingleElement(C&: mergeOps);
2257 assert(merge);
2258 merge->setOperands(valuesToYield);
2259
2260 // MLIR does not allow changing the number of results of an operation, so
2261 // we create a new SelectionOp / LoopOp with required list of results and
2262 // move the region from the initial SelectionOp / LoopOp. The initial
2263 // operation is then removed. Since we move the region to the new op all
2264 // links between blocks and remapping we have previously done should be
2265 // preserved.
2266 builder.setInsertionPoint(&mergeBlock->front());
2267
2268 Operation *newOp = nullptr;
2269
2270 if (isLoop)
2271 newOp = builder.create<spirv::LoopOp>(
2272 location, args: TypeRange(ValueRange(outsideUses)),
2273 args: static_cast<spirv::LoopControl>(control));
2274 else
2275 newOp = builder.create<spirv::SelectionOp>(
2276 location, args: TypeRange(ValueRange(outsideUses)),
2277 args: static_cast<spirv::SelectionControl>(control));
2278
2279 newOp->getRegion(index: 0).takeBody(other&: body);
2280
2281 // Remove initial op and swap the pointer to the newly created one.
2282 op->erase();
2283 op = newOp;
2284
2285 // Update all outside uses to use results of the SelectionOp / LoopOp and
2286 // remove block arguments from the original merge block.
2287 for (unsigned i = 0, e = outsideUses.size(); i != e; ++i)
2288 outsideUses[i].replaceAllUsesWith(newValue: op->getResult(idx: i));
2289
2290 // We do not support block arguments in loop merge block. Also running this
2291 // function with loop would break some of the loop specific code above
2292 // dealing with block arguments.
2293 if (!isLoop)
2294 mergeBlock->eraseArguments(start: 0, num: mergeBlock->getNumArguments());
2295 }
2296
2297 // Check that whether some op in the to-be-erased blocks still has uses. Those
2298 // uses come from blocks that won't be sinked into the SelectionOp/LoopOp's
2299 // region. We cannot handle such cases given that once a value is sinked into
2300 // the SelectionOp/LoopOp's region, there is no escape for it.
2301 for (auto *block : constructBlocks) {
2302 for (Operation &op : *block)
2303 if (!op.use_empty())
2304 return op.emitOpError(message: "failed control flow structurization: value has "
2305 "uses outside of the "
2306 "enclosing selection/loop construct");
2307 for (BlockArgument &arg : block->getArguments())
2308 if (!arg.use_empty())
2309 return emitError(loc: arg.getLoc(), message: "failed control flow structurization: "
2310 "block argument has uses outside of the "
2311 "enclosing selection/loop construct");
2312 }
2313
2314 // Then erase all old blocks.
2315 for (auto *block : constructBlocks) {
2316 // We've cloned all blocks belonging to this construct into the structured
2317 // control flow op's region. Among these blocks, some may compose another
2318 // selection/loop. If so, they will be recorded within blockMergeInfo.
2319 // We need to update the pointers there to the newly remapped ones so we can
2320 // continue structurizing them later.
2321 //
2322 // We need to walk each block as constructBlocks do not include blocks
2323 // internal to ops already structured within those blocks. It is not
2324 // fully clear to me why the mergeInfo of blocks (yet to be structured)
2325 // inside already structured selections/loops get invalidated and needs
2326 // updating, however the following example code can cause a crash (depending
2327 // on the structuring order), when the most inner selection is being
2328 // structured after the outer selection and loop have been already
2329 // structured:
2330 //
2331 // spirv.mlir.for {
2332 // // ...
2333 // spirv.mlir.selection {
2334 // // ..
2335 // // A selection region that hasn't been yet structured!
2336 // // ..
2337 // }
2338 // // ...
2339 // }
2340 //
2341 // If the loop gets structured after the outer selection, but before the
2342 // inner selection. Moving the already structured selection inside the loop
2343 // will invalidate the mergeInfo of the region that is not yet structured.
2344 // Just going over constructBlocks will not check and updated header blocks
2345 // inside the already structured selection region. Walking block fixes that.
2346 //
2347 // TODO: If structuring was done in a fixed order starting with inner
2348 // most constructs this most likely not be an issue and the whole code
2349 // section could be removed. However, with the current non-deterministic
2350 // order this is not possible.
2351 //
2352 // TODO: The asserts in the following assumes input SPIR-V blob forms
2353 // correctly nested selection/loop constructs. We should relax this and
2354 // support error cases better.
2355 auto updateMergeInfo = [&](Block *block) -> WalkResult {
2356 auto it = blockMergeInfo.find(Val: block);
2357 if (it != blockMergeInfo.end()) {
2358 // Use the original location for nested selection/loop ops.
2359 Location loc = it->second.loc;
2360
2361 Block *newHeader = mapper.lookupOrNull(from: block);
2362 if (!newHeader)
2363 return emitError(loc, message: "failed control flow structurization: nested "
2364 "loop header block should be remapped!");
2365
2366 Block *newContinue = it->second.continueBlock;
2367 if (newContinue) {
2368 newContinue = mapper.lookupOrNull(from: newContinue);
2369 if (!newContinue)
2370 return emitError(loc, message: "failed control flow structurization: nested "
2371 "loop continue block should be remapped!");
2372 }
2373
2374 Block *newMerge = it->second.mergeBlock;
2375 if (Block *mappedTo = mapper.lookupOrNull(from: newMerge))
2376 newMerge = mappedTo;
2377
2378 // The iterator should be erased before adding a new entry into
2379 // blockMergeInfo to avoid iterator invalidation.
2380 blockMergeInfo.erase(I: it);
2381 blockMergeInfo.try_emplace(Key: newHeader, Args&: loc, Args&: it->second.control, Args&: newMerge,
2382 Args&: newContinue);
2383 }
2384
2385 return WalkResult::advance();
2386 };
2387
2388 if (block->walk(callback&: updateMergeInfo).wasInterrupted())
2389 return failure();
2390
2391 // The structured selection/loop's entry block does not have arguments.
2392 // If the function's header block is also part of the structured control
2393 // flow, we cannot just simply erase it because it may contain arguments
2394 // matching the function signature and used by the cloned blocks.
2395 if (isFnEntryBlock(block)) {
2396 LLVM_DEBUG(logger.startLine() << "[cf] changing entry block " << block
2397 << " to only contain a spirv.Branch op\n");
2398 // Still keep the function entry block for the potential block arguments,
2399 // but replace all ops inside with a branch to the merge block.
2400 block->clear();
2401 builder.setInsertionPointToEnd(block);
2402 builder.create<spirv::BranchOp>(location, args&: mergeBlock);
2403 } else {
2404 LLVM_DEBUG(logger.startLine() << "[cf] erasing block " << block << "\n");
2405 block->erase();
2406 }
2407 }
2408
2409 LLVM_DEBUG(logger.startLine()
2410 << "[cf] after structurizing construct with header block "
2411 << headerBlock << ":\n"
2412 << *op << "\n");
2413
2414 return success();
2415}
2416
2417LogicalResult spirv::Deserializer::wireUpBlockArgument() {
2418 LLVM_DEBUG({
2419 logger.startLine()
2420 << "//----- [phi] start wiring up block arguments -----//\n";
2421 logger.indent();
2422 });
2423
2424 OpBuilder::InsertionGuard guard(opBuilder);
2425
2426 for (const auto &info : blockPhiInfo) {
2427 Block *block = info.first.first;
2428 Block *target = info.first.second;
2429 const BlockPhiInfo &phiInfo = info.second;
2430 LLVM_DEBUG({
2431 logger.startLine() << "[phi] block " << block << "\n";
2432 logger.startLine() << "[phi] before creating block argument:\n";
2433 block->getParentOp()->print(logger.getOStream());
2434 logger.startLine() << "\n";
2435 });
2436
2437 // Set insertion point to before this block's terminator early because we
2438 // may materialize ops via getValue() call.
2439 auto *op = block->getTerminator();
2440 opBuilder.setInsertionPoint(op);
2441
2442 SmallVector<Value, 4> blockArgs;
2443 blockArgs.reserve(N: phiInfo.size());
2444 for (uint32_t valueId : phiInfo) {
2445 if (Value value = getValue(id: valueId)) {
2446 blockArgs.push_back(Elt: value);
2447 LLVM_DEBUG(logger.startLine() << "[phi] block argument " << value
2448 << " id = " << valueId << "\n");
2449 } else {
2450 return emitError(loc: unknownLoc, message: "OpPhi references undefined value!");
2451 }
2452 }
2453
2454 if (auto branchOp = dyn_cast<spirv::BranchOp>(Val: op)) {
2455 // Replace the previous branch op with a new one with block arguments.
2456 opBuilder.create<spirv::BranchOp>(location: branchOp.getLoc(), args: branchOp.getTarget(),
2457 args&: blockArgs);
2458 branchOp.erase();
2459 } else if (auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(Val: op)) {
2460 assert((branchCondOp.getTrueBlock() == target ||
2461 branchCondOp.getFalseBlock() == target) &&
2462 "expected target to be either the true or false target");
2463 if (target == branchCondOp.getTrueTarget())
2464 opBuilder.create<spirv::BranchConditionalOp>(
2465 location: branchCondOp.getLoc(), args: branchCondOp.getCondition(), args&: blockArgs,
2466 args: branchCondOp.getFalseBlockArguments(),
2467 args: branchCondOp.getBranchWeightsAttr(), args: branchCondOp.getTrueTarget(),
2468 args: branchCondOp.getFalseTarget());
2469 else
2470 opBuilder.create<spirv::BranchConditionalOp>(
2471 location: branchCondOp.getLoc(), args: branchCondOp.getCondition(),
2472 args: branchCondOp.getTrueBlockArguments(), args&: blockArgs,
2473 args: branchCondOp.getBranchWeightsAttr(), args: branchCondOp.getTrueBlock(),
2474 args: branchCondOp.getFalseBlock());
2475
2476 branchCondOp.erase();
2477 } else {
2478 return emitError(loc: unknownLoc, message: "unimplemented terminator for Phi creation");
2479 }
2480
2481 LLVM_DEBUG({
2482 logger.startLine() << "[phi] after creating block argument:\n";
2483 block->getParentOp()->print(logger.getOStream());
2484 logger.startLine() << "\n";
2485 });
2486 }
2487 blockPhiInfo.clear();
2488
2489 LLVM_DEBUG({
2490 logger.unindent();
2491 logger.startLine()
2492 << "//--- [phi] completed wiring up block arguments ---//\n";
2493 });
2494 return success();
2495}
2496
2497LogicalResult spirv::Deserializer::splitConditionalBlocks() {
2498 // Create a copy, so we can modify keys in the original.
2499 BlockMergeInfoMap blockMergeInfoCopy = blockMergeInfo;
2500 for (auto it = blockMergeInfoCopy.begin(), e = blockMergeInfoCopy.end();
2501 it != e; ++it) {
2502 auto &[block, mergeInfo] = *it;
2503
2504 // Skip processing loop regions. For loop regions continueBlock is non-null.
2505 if (mergeInfo.continueBlock)
2506 continue;
2507
2508 if (!block->mightHaveTerminator())
2509 continue;
2510
2511 Operation *terminator = block->getTerminator();
2512 assert(terminator);
2513
2514 if (!isa<spirv::BranchConditionalOp>(Val: terminator))
2515 continue;
2516
2517 // Check if the current header block is a merge block of another construct.
2518 bool splitHeaderMergeBlock = false;
2519 for (const auto &[_, mergeInfo] : blockMergeInfo) {
2520 if (mergeInfo.mergeBlock == block)
2521 splitHeaderMergeBlock = true;
2522 }
2523
2524 // Do not split a block that only contains a conditional branch, unless it
2525 // is also a merge block of another construct - in that case we want to
2526 // split the block. We do not want two constructs to share header / merge
2527 // block.
2528 if (!llvm::hasSingleElement(C&: *block) || splitHeaderMergeBlock) {
2529 Block *newBlock = block->splitBlock(splitBeforeOp: terminator);
2530 OpBuilder builder(block, block->end());
2531 builder.create<spirv::BranchOp>(location: block->getParent()->getLoc(), args&: newBlock);
2532
2533 // After splitting we need to update the map to use the new block as a
2534 // header.
2535 blockMergeInfo.erase(Val: block);
2536 blockMergeInfo.try_emplace(Key: newBlock, Args&: mergeInfo);
2537 }
2538 }
2539
2540 return success();
2541}
2542
2543LogicalResult spirv::Deserializer::structurizeControlFlow() {
2544 if (!options.enableControlFlowStructurization) {
2545 LLVM_DEBUG(
2546 {
2547 logger.startLine()
2548 << "//----- [cf] skip structurizing control flow -----//\n";
2549 logger.indent();
2550 });
2551 return success();
2552 }
2553
2554 LLVM_DEBUG({
2555 logger.startLine()
2556 << "//----- [cf] start structurizing control flow -----//\n";
2557 logger.indent();
2558 });
2559
2560 LLVM_DEBUG({
2561 logger.startLine() << "[cf] split conditional blocks\n";
2562 logger.startLine() << "\n";
2563 });
2564
2565 if (failed(Result: splitConditionalBlocks())) {
2566 return failure();
2567 }
2568
2569 // TODO: This loop is non-deterministic. Iteration order may vary between runs
2570 // for the same shader as the key to the map is a pointer. See:
2571 // https://github.com/llvm/llvm-project/issues/128547
2572 while (!blockMergeInfo.empty()) {
2573 Block *headerBlock = blockMergeInfo.begin()->first;
2574 BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second;
2575
2576 LLVM_DEBUG({
2577 logger.startLine() << "[cf] header block " << headerBlock << ":\n";
2578 headerBlock->print(logger.getOStream());
2579 logger.startLine() << "\n";
2580 });
2581
2582 auto *mergeBlock = mergeInfo.mergeBlock;
2583 assert(mergeBlock && "merge block cannot be nullptr");
2584 if (mergeInfo.continueBlock && !mergeBlock->args_empty())
2585 return emitError(loc: unknownLoc, message: "OpPhi in loop merge block unimplemented");
2586 LLVM_DEBUG({
2587 logger.startLine() << "[cf] merge block " << mergeBlock << ":\n";
2588 mergeBlock->print(logger.getOStream());
2589 logger.startLine() << "\n";
2590 });
2591
2592 auto *continueBlock = mergeInfo.continueBlock;
2593 LLVM_DEBUG(if (continueBlock) {
2594 logger.startLine() << "[cf] continue block " << continueBlock << ":\n";
2595 continueBlock->print(logger.getOStream());
2596 logger.startLine() << "\n";
2597 });
2598 // Erase this case before calling into structurizer, who will update
2599 // blockMergeInfo.
2600 blockMergeInfo.erase(I: blockMergeInfo.begin());
2601 ControlFlowStructurizer structurizer(mergeInfo.loc, mergeInfo.control,
2602 blockMergeInfo, headerBlock,
2603 mergeBlock, continueBlock
2604#ifndef NDEBUG
2605 ,
2606 logger
2607#endif
2608 );
2609 if (failed(Result: structurizer.structurize()))
2610 return failure();
2611 }
2612
2613 LLVM_DEBUG({
2614 logger.unindent();
2615 logger.startLine()
2616 << "//--- [cf] completed structurizing control flow ---//\n";
2617 });
2618 return success();
2619}
2620
2621//===----------------------------------------------------------------------===//
2622// Debug
2623//===----------------------------------------------------------------------===//
2624
2625Location spirv::Deserializer::createFileLineColLoc(OpBuilder opBuilder) {
2626 if (!debugLine)
2627 return unknownLoc;
2628
2629 auto fileName = debugInfoMap.lookup(Val: debugLine->fileID).str();
2630 if (fileName.empty())
2631 fileName = "<unknown>";
2632 return FileLineColLoc::get(filename: opBuilder.getStringAttr(bytes: fileName), line: debugLine->line,
2633 column: debugLine->column);
2634}
2635
2636LogicalResult
2637spirv::Deserializer::processDebugLine(ArrayRef<uint32_t> operands) {
2638 // According to SPIR-V spec:
2639 // "This location information applies to the instructions physically
2640 // following this instruction, up to the first occurrence of any of the
2641 // following: the next end of block, the next OpLine instruction, or the next
2642 // OpNoLine instruction."
2643 if (operands.size() != 3)
2644 return emitError(loc: unknownLoc, message: "OpLine must have 3 operands");
2645 debugLine = DebugLine{.fileID: operands[0], .line: operands[1], .column: operands[2]};
2646 return success();
2647}
2648
2649void spirv::Deserializer::clearDebugLine() { debugLine = std::nullopt; }
2650
2651LogicalResult
2652spirv::Deserializer::processDebugString(ArrayRef<uint32_t> operands) {
2653 if (operands.size() < 2)
2654 return emitError(loc: unknownLoc, message: "OpString needs at least 2 operands");
2655
2656 if (!debugInfoMap.lookup(Val: operands[0]).empty())
2657 return emitError(loc: unknownLoc,
2658 message: "duplicate debug string found for result <id> ")
2659 << operands[0];
2660
2661 unsigned wordIndex = 1;
2662 StringRef debugString = decodeStringLiteral(words: operands, wordIndex);
2663 if (wordIndex != operands.size())
2664 return emitError(loc: unknownLoc,
2665 message: "unexpected trailing words in OpString instruction");
2666
2667 debugInfoMap[operands[0]] = debugString;
2668 return success();
2669}
2670

// Source code of mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp