1//===- Deserializer.cpp - MLIR SPIR-V Deserializer ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the SPIR-V binary to MLIR SPIR-V module deserializer.
10//
11//===----------------------------------------------------------------------===//
12
13#include "Deserializer.h"
14
15#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
16#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
17#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19#include "mlir/IR/Builders.h"
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/Location.h"
22#include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/Sequence.h"
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/ADT/StringExtras.h"
27#include "llvm/ADT/bit.h"
28#include "llvm/Support/Debug.h"
29#include "llvm/Support/SaveAndRestore.h"
30#include "llvm/Support/raw_ostream.h"
31#include <optional>
32
33using namespace mlir;
34
35#define DEBUG_TYPE "spirv-deserialization"
36
37//===----------------------------------------------------------------------===//
38// Utility Functions
39//===----------------------------------------------------------------------===//
40
41/// Returns true if the given `block` is a function entry block.
42static inline bool isFnEntryBlock(Block *block) {
43 return block->isEntryBlock() &&
44 isa_and_nonnull<spirv::FuncOp>(block->getParentOp());
45}
46
47//===----------------------------------------------------------------------===//
48// Deserializer Method Definitions
49//===----------------------------------------------------------------------===//
50
51spirv::Deserializer::Deserializer(ArrayRef<uint32_t> binary,
52 MLIRContext *context,
53 const spirv::DeserializationOptions &options)
54 : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)),
55 module(createModuleOp()), opBuilder(module->getRegion()), options(options)
56#ifndef NDEBUG
57 ,
58 logger(llvm::dbgs())
59#endif
60{
61}
62
63LogicalResult spirv::Deserializer::deserialize() {
64 LLVM_DEBUG({
65 logger.resetIndent();
66 logger.startLine()
67 << "//+++---------- start deserialization ----------+++//\n";
68 });
69
70 if (failed(Result: processHeader()))
71 return failure();
72
73 spirv::Opcode opcode = spirv::Opcode::OpNop;
74 ArrayRef<uint32_t> operands;
75 auto binarySize = binary.size();
76 while (curOffset < binarySize) {
77 // Slice the next instruction out and populate `opcode` and `operands`.
78 // Internally this also updates `curOffset`.
79 if (failed(sliceInstruction(opcode, operands)))
80 return failure();
81
82 if (failed(processInstruction(opcode, operands)))
83 return failure();
84 }
85
86 assert(curOffset == binarySize &&
87 "deserializer should never index beyond the binary end");
88
89 for (auto &deferred : deferredInstructions) {
90 if (failed(processInstruction(deferred.first, deferred.second, false))) {
91 return failure();
92 }
93 }
94
95 attachVCETriple();
96
97 LLVM_DEBUG(logger.startLine()
98 << "//+++-------- completed deserialization --------+++//\n");
99 return success();
100}
101
102OwningOpRef<spirv::ModuleOp> spirv::Deserializer::collect() {
103 return std::move(module);
104}
105
106//===----------------------------------------------------------------------===//
107// Module structure
108//===----------------------------------------------------------------------===//
109
110OwningOpRef<spirv::ModuleOp> spirv::Deserializer::createModuleOp() {
111 OpBuilder builder(context);
112 OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
113 spirv::ModuleOp::build(builder, state);
114 return cast<spirv::ModuleOp>(Operation::create(state));
115}
116
117LogicalResult spirv::Deserializer::processHeader() {
118 if (binary.size() < spirv::kHeaderWordCount)
119 return emitError(loc: unknownLoc,
120 message: "SPIR-V binary module must have a 5-word header");
121
122 if (binary[0] != spirv::kMagicNumber)
123 return emitError(loc: unknownLoc, message: "incorrect magic number");
124
125 // Version number bytes: 0 | major number | minor number | 0
126 uint32_t majorVersion = (binary[1] << 8) >> 24;
127 uint32_t minorVersion = (binary[1] << 16) >> 24;
128 if (majorVersion == 1) {
129 switch (minorVersion) {
130#define MIN_VERSION_CASE(v) \
131 case v: \
132 version = spirv::Version::V_1_##v; \
133 break
134
135 MIN_VERSION_CASE(0);
136 MIN_VERSION_CASE(1);
137 MIN_VERSION_CASE(2);
138 MIN_VERSION_CASE(3);
139 MIN_VERSION_CASE(4);
140 MIN_VERSION_CASE(5);
141#undef MIN_VERSION_CASE
142 default:
143 return emitError(loc: unknownLoc, message: "unsupported SPIR-V minor version: ")
144 << minorVersion;
145 }
146 } else {
147 return emitError(loc: unknownLoc, message: "unsupported SPIR-V major version: ")
148 << majorVersion;
149 }
150
151 // TODO: generator number, bound, schema
152 curOffset = spirv::kHeaderWordCount;
153 return success();
154}
155
156LogicalResult
157spirv::Deserializer::processCapability(ArrayRef<uint32_t> operands) {
158 if (operands.size() != 1)
159 return emitError(loc: unknownLoc, message: "OpCapability must have one parameter");
160
161 auto cap = spirv::symbolizeCapability(operands[0]);
162 if (!cap)
163 return emitError(loc: unknownLoc, message: "unknown capability: ") << operands[0];
164
165 capabilities.insert(*cap);
166 return success();
167}
168
169LogicalResult spirv::Deserializer::processExtension(ArrayRef<uint32_t> words) {
170 if (words.empty()) {
171 return emitError(
172 loc: unknownLoc,
173 message: "OpExtension must have a literal string for the extension name");
174 }
175
176 unsigned wordIndex = 0;
177 StringRef extName = decodeStringLiteral(words, wordIndex);
178 if (wordIndex != words.size())
179 return emitError(loc: unknownLoc,
180 message: "unexpected trailing words in OpExtension instruction");
181 auto ext = spirv::symbolizeExtension(extName);
182 if (!ext)
183 return emitError(loc: unknownLoc, message: "unknown extension: ") << extName;
184
185 extensions.insert(*ext);
186 return success();
187}
188
189LogicalResult
190spirv::Deserializer::processExtInstImport(ArrayRef<uint32_t> words) {
191 if (words.size() < 2) {
192 return emitError(loc: unknownLoc,
193 message: "OpExtInstImport must have a result <id> and a literal "
194 "string for the extended instruction set name");
195 }
196
197 unsigned wordIndex = 1;
198 extendedInstSets[words[0]] = decodeStringLiteral(words, wordIndex);
199 if (wordIndex != words.size()) {
200 return emitError(loc: unknownLoc,
201 message: "unexpected trailing words in OpExtInstImport");
202 }
203 return success();
204}
205
206void spirv::Deserializer::attachVCETriple() {
207 (*module)->setAttr(
208 spirv::ModuleOp::getVCETripleAttrName(),
209 spirv::VerCapExtAttr::get(version, capabilities.getArrayRef(),
210 extensions.getArrayRef(), context));
211}
212
213LogicalResult
214spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
215 if (operands.size() != 2)
216 return emitError(loc: unknownLoc, message: "OpMemoryModel must have two operands");
217
218 (*module)->setAttr(
219 module->getAddressingModelAttrName(),
220 opBuilder.getAttr<spirv::AddressingModelAttr>(
221 static_cast<spirv::AddressingModel>(operands.front())));
222
223 (*module)->setAttr(module->getMemoryModelAttrName(),
224 opBuilder.getAttr<spirv::MemoryModelAttr>(
225 static_cast<spirv::MemoryModel>(operands.back())));
226
227 return success();
228}
229
230template <typename AttrTy, typename EnumAttrTy, typename EnumTy>
231LogicalResult deserializeCacheControlDecoration(
232 Location loc, OpBuilder &opBuilder,
233 DenseMap<uint32_t, NamedAttrList> &decorations, ArrayRef<uint32_t> words,
234 StringAttr symbol, StringRef decorationName, StringRef cacheControlKind) {
235 if (words.size() != 4) {
236 return emitError(loc, message: "OpDecoration with ")
237 << decorationName << "needs a cache control integer literal and a "
238 << cacheControlKind << " cache control literal";
239 }
240 unsigned cacheLevel = words[2];
241 auto cacheControlAttr = static_cast<EnumTy>(words[3]);
242 auto value = opBuilder.getAttr<AttrTy>(cacheLevel, cacheControlAttr);
243 SmallVector<Attribute> attrs;
244 if (auto attrList =
245 llvm::dyn_cast_or_null<ArrayAttr>(decorations[words[0]].get(symbol)))
246 llvm::append_range(attrs, attrList);
247 attrs.push_back(Elt: value);
248 decorations[words[0]].set(symbol, opBuilder.getArrayAttr(attrs));
249 return success();
250}
251
252LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
253 // TODO: This function should also be auto-generated. For now, since only a
254 // few decorations are processed/handled in a meaningful manner, going with a
255 // manual implementation.
256 if (words.size() < 2) {
257 return emitError(
258 loc: unknownLoc, message: "OpDecorate must have at least result <id> and Decoration");
259 }
260 auto decorationName =
261 stringifyDecoration(static_cast<spirv::Decoration>(words[1]));
262 if (decorationName.empty()) {
263 return emitError(loc: unknownLoc, message: "invalid Decoration code : ") << words[1];
264 }
265 auto symbol = getSymbolDecoration(decorationName: decorationName);
266 switch (static_cast<spirv::Decoration>(words[1])) {
267 case spirv::Decoration::FPFastMathMode:
268 if (words.size() != 3) {
269 return emitError(loc: unknownLoc, message: "OpDecorate with ")
270 << decorationName << " needs a single integer literal";
271 }
272 decorations[words[0]].set(
273 symbol, FPFastMathModeAttr::get(opBuilder.getContext(),
274 static_cast<FPFastMathMode>(words[2])));
275 break;
276 case spirv::Decoration::FPRoundingMode:
277 if (words.size() != 3) {
278 return emitError(loc: unknownLoc, message: "OpDecorate with ")
279 << decorationName << " needs a single integer literal";
280 }
281 decorations[words[0]].set(
282 symbol, FPRoundingModeAttr::get(opBuilder.getContext(),
283 static_cast<FPRoundingMode>(words[2])));
284 break;
285 case spirv::Decoration::DescriptorSet:
286 case spirv::Decoration::Binding:
287 if (words.size() != 3) {
288 return emitError(loc: unknownLoc, message: "OpDecorate with ")
289 << decorationName << " needs a single integer literal";
290 }
291 decorations[words[0]].set(
292 symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
293 break;
294 case spirv::Decoration::BuiltIn:
295 if (words.size() != 3) {
296 return emitError(loc: unknownLoc, message: "OpDecorate with ")
297 << decorationName << " needs a single integer literal";
298 }
299 decorations[words[0]].set(
300 symbol, opBuilder.getStringAttr(
301 bytes: stringifyBuiltIn(static_cast<spirv::BuiltIn>(words[2]))));
302 break;
303 case spirv::Decoration::ArrayStride:
304 if (words.size() != 3) {
305 return emitError(loc: unknownLoc, message: "OpDecorate with ")
306 << decorationName << " needs a single integer literal";
307 }
308 typeDecorations[words[0]] = words[2];
309 break;
310 case spirv::Decoration::LinkageAttributes: {
311 if (words.size() < 4) {
312 return emitError(loc: unknownLoc, message: "OpDecorate with ")
313 << decorationName
314 << " needs at least 1 string and 1 integer literal";
315 }
316 // LinkageAttributes has two parameters ["linkageName", linkageType]
317 // e.g., OpDecorate %imported_func LinkageAttributes "outside.func" Import
318 // "linkageName" is a stringliteral encoded as uint32_t,
319 // hence the size of name is variable length which results in words.size()
320 // being variable length, words.size() = 3 + strlen(name)/4 + 1 or
321 // 3 + ceildiv(strlen(name), 4).
322 unsigned wordIndex = 2;
323 auto linkageName = spirv::decodeStringLiteral(words, wordIndex).str();
324 auto linkageTypeAttr = opBuilder.getAttr<::mlir::spirv::LinkageTypeAttr>(
325 static_cast<::mlir::spirv::LinkageType>(words[wordIndex++]));
326 auto linkageAttr = opBuilder.getAttr<::mlir::spirv::LinkageAttributesAttr>(
327 StringAttr::get(context, linkageName), linkageTypeAttr);
328 decorations[words[0]].set(symbol, llvm::dyn_cast<Attribute>(linkageAttr));
329 break;
330 }
331 case spirv::Decoration::Aliased:
332 case spirv::Decoration::AliasedPointer:
333 case spirv::Decoration::Block:
334 case spirv::Decoration::BufferBlock:
335 case spirv::Decoration::Flat:
336 case spirv::Decoration::NonReadable:
337 case spirv::Decoration::NonWritable:
338 case spirv::Decoration::NoPerspective:
339 case spirv::Decoration::NoSignedWrap:
340 case spirv::Decoration::NoUnsignedWrap:
341 case spirv::Decoration::RelaxedPrecision:
342 case spirv::Decoration::Restrict:
343 case spirv::Decoration::RestrictPointer:
344 case spirv::Decoration::NoContraction:
345 case spirv::Decoration::Constant:
346 if (words.size() != 2) {
347 return emitError(loc: unknownLoc, message: "OpDecoration with ")
348 << decorationName << "needs a single target <id>";
349 }
350 // Block decoration does not affect spirv.struct type, but is still stored
351 // for verification.
352 // TODO: Update StructType to contain this information since
353 // it is needed for many validation rules.
354 decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
355 break;
356 case spirv::Decoration::Location:
357 case spirv::Decoration::SpecId:
358 if (words.size() != 3) {
359 return emitError(loc: unknownLoc, message: "OpDecoration with ")
360 << decorationName << "needs a single integer literal";
361 }
362 decorations[words[0]].set(
363 symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
364 break;
365 case spirv::Decoration::CacheControlLoadINTEL: {
366 LogicalResult res = deserializeCacheControlDecoration<
367 CacheControlLoadINTELAttr, LoadCacheControlAttr, LoadCacheControl>(
368 unknownLoc, opBuilder, decorations, words, symbol, decorationName,
369 "load");
370 if (failed(Result: res))
371 return res;
372 break;
373 }
374 case spirv::Decoration::CacheControlStoreINTEL: {
375 LogicalResult res = deserializeCacheControlDecoration<
376 CacheControlStoreINTELAttr, StoreCacheControlAttr, StoreCacheControl>(
377 unknownLoc, opBuilder, decorations, words, symbol, decorationName,
378 "store");
379 if (failed(Result: res))
380 return res;
381 break;
382 }
383 default:
384 return emitError(loc: unknownLoc, message: "unhandled Decoration : '") << decorationName;
385 }
386 return success();
387}
388
389LogicalResult
390spirv::Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) {
391 // The binary layout of OpMemberDecorate is different comparing to OpDecorate
392 if (words.size() < 3) {
393 return emitError(loc: unknownLoc,
394 message: "OpMemberDecorate must have at least 3 operands");
395 }
396
397 auto decoration = static_cast<spirv::Decoration>(words[2]);
398 if (decoration == spirv::Decoration::Offset && words.size() != 4) {
399 return emitError(loc: unknownLoc,
400 message: " missing offset specification in OpMemberDecorate with "
401 "Offset decoration");
402 }
403 ArrayRef<uint32_t> decorationOperands;
404 if (words.size() > 3) {
405 decorationOperands = words.slice(N: 3);
406 }
407 memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands;
408 return success();
409}
410
411LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) {
412 if (words.size() < 3) {
413 return emitError(loc: unknownLoc, message: "OpMemberName must have at least 3 operands");
414 }
415 unsigned wordIndex = 2;
416 auto name = decodeStringLiteral(words, wordIndex);
417 if (wordIndex != words.size()) {
418 return emitError(loc: unknownLoc,
419 message: "unexpected trailing words in OpMemberName instruction");
420 }
421 memberNameMap[words[0]][words[1]] = name;
422 return success();
423}
424
425LogicalResult spirv::Deserializer::setFunctionArgAttrs(
426 uint32_t argID, SmallVectorImpl<Attribute> &argAttrs, size_t argIndex) {
427 if (!decorations.contains(Val: argID)) {
428 argAttrs[argIndex] = DictionaryAttr::get(context, {});
429 return success();
430 }
431
432 spirv::DecorationAttr foundDecorationAttr;
433 for (NamedAttribute decAttr : decorations[argID]) {
434 for (auto decoration :
435 {spirv::Decoration::Aliased, spirv::Decoration::Restrict,
436 spirv::Decoration::AliasedPointer,
437 spirv::Decoration::RestrictPointer}) {
438
439 if (decAttr.getName() !=
440 getSymbolDecoration(stringifyDecoration(decoration)))
441 continue;
442
443 if (foundDecorationAttr)
444 return emitError(unknownLoc,
445 "more than one Aliased/Restrict decorations for "
446 "function argument with result <id> ")
447 << argID;
448
449 foundDecorationAttr = spirv::DecorationAttr::get(context, decoration);
450 break;
451 }
452
453 if (decAttr.getName() == getSymbolDecoration(stringifyDecoration(
454 spirv::Decoration::RelaxedPrecision))) {
455 // TODO: Current implementation supports only one decoration per function
456 // parameter so RelaxedPrecision cannot be applied at the same time as,
457 // for example, Aliased/Restrict/etc. This should be relaxed to allow any
458 // combination of decoration allowed by the spec to be supported.
459 if (foundDecorationAttr)
460 return emitError(loc: unknownLoc, message: "already found a decoration for function "
461 "argument with result <id> ")
462 << argID;
463
464 foundDecorationAttr = spirv::DecorationAttr::get(
465 context, spirv::Decoration::RelaxedPrecision);
466 }
467 }
468
469 if (!foundDecorationAttr)
470 return emitError(loc: unknownLoc, message: "unimplemented decoration support for "
471 "function argument with result <id> ")
472 << argID;
473
474 NamedAttribute attr(StringAttr::get(context, spirv::DecorationAttr::name),
475 foundDecorationAttr);
476 argAttrs[argIndex] = DictionaryAttr::get(context, attr);
477 return success();
478}
479
480LogicalResult
481spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) {
482 if (curFunction) {
483 return emitError(loc: unknownLoc, message: "found function inside function");
484 }
485
486 // Get the result type
487 if (operands.size() != 4) {
488 return emitError(loc: unknownLoc, message: "OpFunction must have 4 parameters");
489 }
490 Type resultType = getType(id: operands[0]);
491 if (!resultType) {
492 return emitError(loc: unknownLoc, message: "undefined result type from <id> ")
493 << operands[0];
494 }
495
496 uint32_t fnID = operands[1];
497 if (funcMap.count(fnID)) {
498 return emitError(loc: unknownLoc, message: "duplicate function definition/declaration");
499 }
500
501 auto fnControl = spirv::symbolizeFunctionControl(operands[2]);
502 if (!fnControl) {
503 return emitError(loc: unknownLoc, message: "unknown Function Control: ") << operands[2];
504 }
505
506 Type fnType = getType(id: operands[3]);
507 if (!fnType || !isa<FunctionType>(Val: fnType)) {
508 return emitError(loc: unknownLoc, message: "unknown function type from <id> ")
509 << operands[3];
510 }
511 auto functionType = cast<FunctionType>(fnType);
512
513 if ((isVoidType(type: resultType) && functionType.getNumResults() != 0) ||
514 (functionType.getNumResults() == 1 &&
515 functionType.getResult(0) != resultType)) {
516 return emitError(loc: unknownLoc, message: "mismatch in function type ")
517 << functionType << " and return type " << resultType << " specified";
518 }
519
520 std::string fnName = getFunctionSymbol(id: fnID);
521 auto funcOp = opBuilder.create<spirv::FuncOp>(
522 unknownLoc, fnName, functionType, fnControl.value());
523 // Processing other function attributes.
524 if (decorations.count(Val: fnID)) {
525 for (auto attr : decorations[fnID].getAttrs()) {
526 funcOp->setAttr(attr.getName(), attr.getValue());
527 }
528 }
529 curFunction = funcMap[fnID] = funcOp;
530 auto *entryBlock = funcOp.addEntryBlock();
531 LLVM_DEBUG({
532 logger.startLine()
533 << "//===-------------------------------------------===//\n";
534 logger.startLine() << "[fn] name: " << fnName << "\n";
535 logger.startLine() << "[fn] type: " << fnType << "\n";
536 logger.startLine() << "[fn] ID: " << fnID << "\n";
537 logger.startLine() << "[fn] entry block: " << entryBlock << "\n";
538 logger.indent();
539 });
540
541 SmallVector<Attribute> argAttrs;
542 argAttrs.resize(functionType.getNumInputs());
543
544 // Parse the op argument instructions
545 if (functionType.getNumInputs()) {
546 for (size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
547 auto argType = functionType.getInput(i);
548 spirv::Opcode opcode = spirv::Opcode::OpNop;
549 ArrayRef<uint32_t> operands;
550 if (failed(sliceInstruction(opcode, operands,
551 spirv::Opcode::OpFunctionParameter))) {
552 return failure();
553 }
554 if (opcode != spirv::Opcode::OpFunctionParameter) {
555 return emitError(
556 loc: unknownLoc,
557 message: "missing OpFunctionParameter instruction for argument ")
558 << i;
559 }
560 if (operands.size() != 2) {
561 return emitError(
562 loc: unknownLoc,
563 message: "expected result type and result <id> for OpFunctionParameter");
564 }
565 auto argDefinedType = getType(id: operands[0]);
566 if (!argDefinedType || argDefinedType != argType) {
567 return emitError(loc: unknownLoc,
568 message: "mismatch in argument type between function type "
569 "definition ")
570 << functionType << " and argument type definition "
571 << argDefinedType << " at argument " << i;
572 }
573 if (getValue(id: operands[1])) {
574 return emitError(loc: unknownLoc, message: "duplicate definition of result <id> ")
575 << operands[1];
576 }
577 if (failed(Result: setFunctionArgAttrs(argID: operands[1], argAttrs, argIndex: i))) {
578 return failure();
579 }
580
581 auto argValue = funcOp.getArgument(i);
582 valueMap[operands[1]] = argValue;
583 }
584 }
585
586 if (llvm::any_of(argAttrs, [](Attribute attr) {
587 auto argAttr = cast<DictionaryAttr>(attr);
588 return !argAttr.empty();
589 }))
590 funcOp.setArgAttrsAttr(ArrayAttr::get(context, argAttrs));
591
592 // entryBlock is needed to access the arguments, Once that is done, we can
593 // erase the block for functions with 'Import' LinkageAttributes, since these
594 // are essentially function declarations, so they have no body.
595 auto linkageAttr = funcOp.getLinkageAttributes();
596 auto hasImportLinkage =
597 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
598 spirv::LinkageType::Import);
599 if (hasImportLinkage)
600 funcOp.eraseBody();
601
602 // RAII guard to reset the insertion point to the module's region after
603 // deserializing the body of this function.
604 OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
605
606 spirv::Opcode opcode = spirv::Opcode::OpNop;
607 ArrayRef<uint32_t> instOperands;
608
609 // Special handling for the entry block. We need to make sure it starts with
610 // an OpLabel instruction. The entry block takes the same parameters as the
611 // function. All other blocks do not take any parameter. We have already
612 // created the entry block, here we need to register it to the correct label
613 // <id>.
614 if (failed(sliceInstruction(opcode, instOperands,
615 spirv::Opcode::OpFunctionEnd))) {
616 return failure();
617 }
618 if (opcode == spirv::Opcode::OpFunctionEnd) {
619 return processFunctionEnd(operands: instOperands);
620 }
621 if (opcode != spirv::Opcode::OpLabel) {
622 return emitError(loc: unknownLoc, message: "a basic block must start with OpLabel");
623 }
624 if (instOperands.size() != 1) {
625 return emitError(loc: unknownLoc, message: "OpLabel should only have result <id>");
626 }
627 blockMap[instOperands[0]] = entryBlock;
628 if (failed(Result: processLabel(operands: instOperands))) {
629 return failure();
630 }
631
632 // Then process all the other instructions in the function until we hit
633 // OpFunctionEnd.
634 while (succeeded(sliceInstruction(opcode, instOperands,
635 spirv::Opcode::OpFunctionEnd)) &&
636 opcode != spirv::Opcode::OpFunctionEnd) {
637 if (failed(processInstruction(opcode, instOperands))) {
638 return failure();
639 }
640 }
641 if (opcode != spirv::Opcode::OpFunctionEnd) {
642 return failure();
643 }
644
645 return processFunctionEnd(operands: instOperands);
646}
647
648LogicalResult
649spirv::Deserializer::processFunctionEnd(ArrayRef<uint32_t> operands) {
650 // Process OpFunctionEnd.
651 if (!operands.empty()) {
652 return emitError(loc: unknownLoc, message: "unexpected operands for OpFunctionEnd");
653 }
654
655 // Wire up block arguments from OpPhi instructions.
656 // Put all structured control flow in spirv.mlir.selection/spirv.mlir.loop
657 // ops.
658 if (failed(Result: wireUpBlockArgument()) || failed(Result: structurizeControlFlow())) {
659 return failure();
660 }
661
662 curBlock = nullptr;
663 curFunction = std::nullopt;
664
665 LLVM_DEBUG({
666 logger.unindent();
667 logger.startLine()
668 << "//===-------------------------------------------===//\n";
669 });
670 return success();
671}
672
673std::optional<std::pair<Attribute, Type>>
674spirv::Deserializer::getConstant(uint32_t id) {
675 auto constIt = constantMap.find(Val: id);
676 if (constIt == constantMap.end())
677 return std::nullopt;
678 return constIt->getSecond();
679}
680
681std::optional<spirv::SpecConstOperationMaterializationInfo>
682spirv::Deserializer::getSpecConstantOperation(uint32_t id) {
683 auto constIt = specConstOperationMap.find(Val: id);
684 if (constIt == specConstOperationMap.end())
685 return std::nullopt;
686 return constIt->getSecond();
687}
688
689std::string spirv::Deserializer::getFunctionSymbol(uint32_t id) {
690 auto funcName = nameMap.lookup(Val: id).str();
691 if (funcName.empty()) {
692 funcName = "spirv_fn_" + std::to_string(val: id);
693 }
694 return funcName;
695}
696
697std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t id) {
698 auto constName = nameMap.lookup(Val: id).str();
699 if (constName.empty()) {
700 constName = "spirv_spec_const_" + std::to_string(val: id);
701 }
702 return constName;
703}
704
705spirv::SpecConstantOp
706spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID,
707 TypedAttr defaultValue) {
708 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(id: resultID));
709 auto op = opBuilder.create<spirv::SpecConstantOp>(unknownLoc, symName,
710 defaultValue);
711 if (decorations.count(Val: resultID)) {
712 for (auto attr : decorations[resultID].getAttrs())
713 op->setAttr(attr.getName(), attr.getValue());
714 }
715 specConstMap[resultID] = op;
716 return op;
717}
718
719LogicalResult
720spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
721 unsigned wordIndex = 0;
722 if (operands.size() < 3) {
723 return emitError(
724 loc: unknownLoc,
725 message: "OpVariable needs at least 3 operands, type, <id> and storage class");
726 }
727
728 // Result Type.
729 auto type = getType(id: operands[wordIndex]);
730 if (!type) {
731 return emitError(loc: unknownLoc, message: "unknown result type <id> : ")
732 << operands[wordIndex];
733 }
734 auto ptrType = dyn_cast<spirv::PointerType>(Val&: type);
735 if (!ptrType) {
736 return emitError(loc: unknownLoc,
737 message: "expected a result type <id> to be a spirv.ptr, found : ")
738 << type;
739 }
740 wordIndex++;
741
742 // Result <id>.
743 auto variableID = operands[wordIndex];
744 auto variableName = nameMap.lookup(Val: variableID).str();
745 if (variableName.empty()) {
746 variableName = "spirv_var_" + std::to_string(val: variableID);
747 }
748 wordIndex++;
749
750 // Storage class.
751 auto storageClass = static_cast<spirv::StorageClass>(operands[wordIndex]);
752 if (ptrType.getStorageClass() != storageClass) {
753 return emitError(loc: unknownLoc, message: "mismatch in storage class of pointer type ")
754 << type << " and that specified in OpVariable instruction : "
755 << stringifyStorageClass(storageClass);
756 }
757 wordIndex++;
758
759 // Initializer.
760 FlatSymbolRefAttr initializer = nullptr;
761
762 if (wordIndex < operands.size()) {
763 Operation *op = nullptr;
764
765 if (auto initOp = getGlobalVariable(operands[wordIndex]))
766 op = initOp;
767 else if (auto initOp = getSpecConstant(operands[wordIndex]))
768 op = initOp;
769 else if (auto initOp = getSpecConstantComposite(operands[wordIndex]))
770 op = initOp;
771 else
772 return emitError(loc: unknownLoc, message: "unknown <id> ")
773 << operands[wordIndex] << "used as initializer";
774
775 initializer = SymbolRefAttr::get(op);
776 wordIndex++;
777 }
778 if (wordIndex != operands.size()) {
779 return emitError(loc: unknownLoc,
780 message: "found more operands than expected when deserializing "
781 "OpVariable instruction, only ")
782 << wordIndex << " of " << operands.size() << " processed";
783 }
784 auto loc = createFileLineColLoc(opBuilder);
785 auto varOp = opBuilder.create<spirv::GlobalVariableOp>(
786 loc, TypeAttr::get(type), opBuilder.getStringAttr(variableName),
787 initializer);
788
789 // Decorations.
790 if (decorations.count(Val: variableID)) {
791 for (auto attr : decorations[variableID].getAttrs())
792 varOp->setAttr(attr.getName(), attr.getValue());
793 }
794 globalVariableMap[variableID] = varOp;
795 return success();
796}
797
798IntegerAttr spirv::Deserializer::getConstantInt(uint32_t id) {
799 auto constInfo = getConstant(id);
800 if (!constInfo) {
801 return nullptr;
802 }
803 return dyn_cast<IntegerAttr>(constInfo->first);
804}
805
806LogicalResult spirv::Deserializer::processName(ArrayRef<uint32_t> operands) {
807 if (operands.size() < 2) {
808 return emitError(loc: unknownLoc, message: "OpName needs at least 2 operands");
809 }
810 if (!nameMap.lookup(Val: operands[0]).empty()) {
811 return emitError(loc: unknownLoc, message: "duplicate name found for result <id> ")
812 << operands[0];
813 }
814 unsigned wordIndex = 1;
815 StringRef name = decodeStringLiteral(words: operands, wordIndex);
816 if (wordIndex != operands.size()) {
817 return emitError(loc: unknownLoc,
818 message: "unexpected trailing words in OpName instruction");
819 }
820 nameMap[operands[0]] = name;
821 return success();
822}
823
824//===----------------------------------------------------------------------===//
825// Type
826//===----------------------------------------------------------------------===//
827
828LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
829 ArrayRef<uint32_t> operands) {
830 if (operands.empty()) {
831 return emitError(unknownLoc, "type instruction with opcode ")
832 << spirv::stringifyOpcode(opcode) << " needs at least one <id>";
833 }
834
835 /// TODO: Types might be forward declared in some instructions and need to be
836 /// handled appropriately.
837 if (typeMap.count(Val: operands[0])) {
838 return emitError(loc: unknownLoc, message: "duplicate definition for result <id> ")
839 << operands[0];
840 }
841
842 switch (opcode) {
843 case spirv::Opcode::OpTypeVoid:
844 if (operands.size() != 1)
845 return emitError(loc: unknownLoc, message: "OpTypeVoid must have no parameters");
846 typeMap[operands[0]] = opBuilder.getNoneType();
847 break;
848 case spirv::Opcode::OpTypeBool:
849 if (operands.size() != 1)
850 return emitError(loc: unknownLoc, message: "OpTypeBool must have no parameters");
851 typeMap[operands[0]] = opBuilder.getI1Type();
852 break;
853 case spirv::Opcode::OpTypeInt: {
854 if (operands.size() != 3)
855 return emitError(
856 loc: unknownLoc, message: "OpTypeInt must have bitwidth and signedness parameters");
857
858 // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
859 // to preserve or validate.
860 // 0 indicates unsigned, or no signedness semantics
861 // 1 indicates signed semantics."
862 //
863 // So we cannot differentiate signless and unsigned integers; always use
864 // signless semantics for such cases.
865 auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed
866 : IntegerType::SignednessSemantics::Signless;
867 typeMap[operands[0]] = IntegerType::get(context, operands[1], sign);
868 } break;
869 case spirv::Opcode::OpTypeFloat: {
870 if (operands.size() != 2)
871 return emitError(loc: unknownLoc, message: "OpTypeFloat must have bitwidth parameter");
872
873 Type floatTy;
874 switch (operands[1]) {
875 case 16:
876 floatTy = opBuilder.getF16Type();
877 break;
878 case 32:
879 floatTy = opBuilder.getF32Type();
880 break;
881 case 64:
882 floatTy = opBuilder.getF64Type();
883 break;
884 default:
885 return emitError(loc: unknownLoc, message: "unsupported OpTypeFloat bitwidth: ")
886 << operands[1];
887 }
888 typeMap[operands[0]] = floatTy;
889 } break;
890 case spirv::Opcode::OpTypeVector: {
891 if (operands.size() != 3) {
892 return emitError(
893 loc: unknownLoc,
894 message: "OpTypeVector must have element type and count parameters");
895 }
896 Type elementTy = getType(id: operands[1]);
897 if (!elementTy) {
898 return emitError(loc: unknownLoc, message: "OpTypeVector references undefined <id> ")
899 << operands[1];
900 }
901 typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy);
902 } break;
903 case spirv::Opcode::OpTypePointer: {
904 return processOpTypePointer(operands);
905 } break;
906 case spirv::Opcode::OpTypeArray:
907 return processArrayType(operands);
908 case spirv::Opcode::OpTypeCooperativeMatrixKHR:
909 return processCooperativeMatrixTypeKHR(operands);
910 case spirv::Opcode::OpTypeFunction:
911 return processFunctionType(operands);
912 case spirv::Opcode::OpTypeImage:
913 return processImageType(operands);
914 case spirv::Opcode::OpTypeSampledImage:
915 return processSampledImageType(operands);
916 case spirv::Opcode::OpTypeRuntimeArray:
917 return processRuntimeArrayType(operands);
918 case spirv::Opcode::OpTypeStruct:
919 return processStructType(operands);
920 case spirv::Opcode::OpTypeMatrix:
921 return processMatrixType(operands);
922 default:
923 return emitError(loc: unknownLoc, message: "unhandled type instruction");
924 }
925 return success();
926}
927
928LogicalResult
929spirv::Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) {
930 if (operands.size() != 3)
931 return emitError(loc: unknownLoc, message: "OpTypePointer must have two parameters");
932
933 auto pointeeType = getType(id: operands[2]);
934 if (!pointeeType)
935 return emitError(loc: unknownLoc, message: "unknown OpTypePointer pointee type <id> ")
936 << operands[2];
937
938 uint32_t typePointerID = operands[0];
939 auto storageClass = static_cast<spirv::StorageClass>(operands[1]);
940 typeMap[typePointerID] = spirv::PointerType::get(pointeeType, storageClass);
941
942 for (auto *deferredStructIt = std::begin(cont&: deferredStructTypesInfos);
943 deferredStructIt != std::end(cont&: deferredStructTypesInfos);) {
944 for (auto *unresolvedMemberIt =
945 std::begin(cont&: deferredStructIt->unresolvedMemberTypes);
946 unresolvedMemberIt !=
947 std::end(cont&: deferredStructIt->unresolvedMemberTypes);) {
948 if (unresolvedMemberIt->first == typePointerID) {
949 // The newly constructed pointer type can resolve one of the
950 // deferred struct type members; update the memberTypes list and
951 // clean the unresolvedMemberTypes list accordingly.
952 deferredStructIt->memberTypes[unresolvedMemberIt->second] =
953 typeMap[typePointerID];
954 unresolvedMemberIt =
955 deferredStructIt->unresolvedMemberTypes.erase(CI: unresolvedMemberIt);
956 } else {
957 ++unresolvedMemberIt;
958 }
959 }
960
961 if (deferredStructIt->unresolvedMemberTypes.empty()) {
962 // All deferred struct type members are now resolved, set the struct body.
963 auto structType = deferredStructIt->deferredStructType;
964
965 assert(structType && "expected a spirv::StructType");
966 assert(structType.isIdentified() && "expected an indentified struct");
967
968 if (failed(Result: structType.trySetBody(
969 memberTypes: deferredStructIt->memberTypes, offsetInfo: deferredStructIt->offsetInfo,
970 memberDecorations: deferredStructIt->memberDecorationsInfo)))
971 return failure();
972
973 deferredStructIt = deferredStructTypesInfos.erase(CI: deferredStructIt);
974 } else {
975 ++deferredStructIt;
976 }
977 }
978
979 return success();
980}
981
982LogicalResult
983spirv::Deserializer::processArrayType(ArrayRef<uint32_t> operands) {
984 if (operands.size() != 3) {
985 return emitError(loc: unknownLoc,
986 message: "OpTypeArray must have element type and count parameters");
987 }
988
989 Type elementTy = getType(id: operands[1]);
990 if (!elementTy) {
991 return emitError(loc: unknownLoc, message: "OpTypeArray references undefined <id> ")
992 << operands[1];
993 }
994
995 unsigned count = 0;
996 // TODO: The count can also come frome a specialization constant.
997 auto countInfo = getConstant(id: operands[2]);
998 if (!countInfo) {
999 return emitError(loc: unknownLoc, message: "OpTypeArray count <id> ")
1000 << operands[2] << "can only come from normal constant right now";
1001 }
1002
1003 if (auto intVal = dyn_cast<IntegerAttr>(countInfo->first)) {
1004 count = intVal.getValue().getZExtValue();
1005 } else {
1006 return emitError(loc: unknownLoc, message: "OpTypeArray count must come from a "
1007 "scalar integer constant instruction");
1008 }
1009
1010 typeMap[operands[0]] = spirv::ArrayType::get(
1011 elementType: elementTy, elementCount: count, stride: typeDecorations.lookup(Val: operands[0]));
1012 return success();
1013}
1014
1015LogicalResult
1016spirv::Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
1017 assert(!operands.empty() && "No operands for processing function type");
1018 if (operands.size() == 1) {
1019 return emitError(loc: unknownLoc, message: "missing return type for OpTypeFunction");
1020 }
1021 auto returnType = getType(id: operands[1]);
1022 if (!returnType) {
1023 return emitError(loc: unknownLoc, message: "unknown return type in OpTypeFunction");
1024 }
1025 SmallVector<Type, 1> argTypes;
1026 for (size_t i = 2, e = operands.size(); i < e; ++i) {
1027 auto ty = getType(id: operands[i]);
1028 if (!ty) {
1029 return emitError(loc: unknownLoc, message: "unknown argument type in OpTypeFunction");
1030 }
1031 argTypes.push_back(Elt: ty);
1032 }
1033 ArrayRef<Type> returnTypes;
1034 if (!isVoidType(type: returnType)) {
1035 returnTypes = llvm::ArrayRef(returnType);
1036 }
1037 typeMap[operands[0]] = FunctionType::get(context, argTypes, returnTypes);
1038 return success();
1039}
1040
1041LogicalResult spirv::Deserializer::processCooperativeMatrixTypeKHR(
1042 ArrayRef<uint32_t> operands) {
1043 if (operands.size() != 6) {
1044 return emitError(loc: unknownLoc,
1045 message: "OpTypeCooperativeMatrixKHR must have element type, "
1046 "scope, row and column parameters, and use");
1047 }
1048
1049 Type elementTy = getType(id: operands[1]);
1050 if (!elementTy) {
1051 return emitError(loc: unknownLoc,
1052 message: "OpTypeCooperativeMatrixKHR references undefined <id> ")
1053 << operands[1];
1054 }
1055
1056 std::optional<spirv::Scope> scope =
1057 spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
1058 if (!scope) {
1059 return emitError(
1060 loc: unknownLoc,
1061 message: "OpTypeCooperativeMatrixKHR references undefined scope <id> ")
1062 << operands[2];
1063 }
1064
1065 IntegerAttr rowsAttr = getConstantInt(operands[3]);
1066 IntegerAttr columnsAttr = getConstantInt(operands[4]);
1067 IntegerAttr useAttr = getConstantInt(operands[5]);
1068
1069 if (!rowsAttr)
1070 return emitError(loc: unknownLoc, message: "OpTypeCooperativeMatrixKHR `Rows` references "
1071 "undefined constant <id> ")
1072 << operands[3];
1073
1074 if (!columnsAttr)
1075 return emitError(loc: unknownLoc, message: "OpTypeCooperativeMatrixKHR `Columns` "
1076 "references undefined constant <id> ")
1077 << operands[4];
1078
1079 if (!useAttr)
1080 return emitError(loc: unknownLoc, message: "OpTypeCooperativeMatrixKHR `Use` references "
1081 "undefined constant <id> ")
1082 << operands[5];
1083
1084 unsigned rows = rowsAttr.getInt();
1085 unsigned columns = columnsAttr.getInt();
1086
1087 std::optional<spirv::CooperativeMatrixUseKHR> use =
1088 spirv::symbolizeCooperativeMatrixUseKHR(useAttr.getInt());
1089 if (!use) {
1090 return emitError(
1091 loc: unknownLoc,
1092 message: "OpTypeCooperativeMatrixKHR references undefined use <id> ")
1093 << operands[5];
1094 }
1095
1096 typeMap[operands[0]] =
1097 spirv::CooperativeMatrixType::get(elementTy, rows, columns, *scope, *use);
1098 return success();
1099}
1100
1101LogicalResult
1102spirv::Deserializer::processRuntimeArrayType(ArrayRef<uint32_t> operands) {
1103 if (operands.size() != 2) {
1104 return emitError(loc: unknownLoc, message: "OpTypeRuntimeArray must have two operands");
1105 }
1106 Type memberType = getType(id: operands[1]);
1107 if (!memberType) {
1108 return emitError(loc: unknownLoc,
1109 message: "OpTypeRuntimeArray references undefined <id> ")
1110 << operands[1];
1111 }
1112 typeMap[operands[0]] = spirv::RuntimeArrayType::get(
1113 elementType: memberType, stride: typeDecorations.lookup(Val: operands[0]));
1114 return success();
1115}
1116
1117LogicalResult
1118spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) {
1119 // TODO: Find a way to handle identified structs when debug info is stripped.
1120
1121 if (operands.empty()) {
1122 return emitError(loc: unknownLoc, message: "OpTypeStruct must have at least result <id>");
1123 }
1124
1125 if (operands.size() == 1) {
1126 // Handle empty struct.
1127 typeMap[operands[0]] =
1128 spirv::StructType::getEmpty(context, identifier: nameMap.lookup(Val: operands[0]).str());
1129 return success();
1130 }
1131
1132 // First element is operand ID, second element is member index in the struct.
1133 SmallVector<std::pair<uint32_t, unsigned>, 0> unresolvedMemberTypes;
1134 SmallVector<Type, 4> memberTypes;
1135
1136 for (auto op : llvm::drop_begin(RangeOrContainer&: operands, N: 1)) {
1137 Type memberType = getType(id: op);
1138 bool typeForwardPtr = (typeForwardPointerIDs.count(key: op) != 0);
1139
1140 if (!memberType && !typeForwardPtr)
1141 return emitError(loc: unknownLoc, message: "OpTypeStruct references undefined <id> ")
1142 << op;
1143
1144 if (!memberType)
1145 unresolvedMemberTypes.emplace_back(Args&: op, Args: memberTypes.size());
1146
1147 memberTypes.push_back(Elt: memberType);
1148 }
1149
1150 SmallVector<spirv::StructType::OffsetInfo, 0> offsetInfo;
1151 SmallVector<spirv::StructType::MemberDecorationInfo, 0> memberDecorationsInfo;
1152 if (memberDecorationMap.count(operands[0])) {
1153 auto &allMemberDecorations = memberDecorationMap[operands[0]];
1154 for (auto memberIndex : llvm::seq<uint32_t>(Begin: 0, End: memberTypes.size())) {
1155 if (allMemberDecorations.count(memberIndex)) {
1156 for (auto &memberDecoration : allMemberDecorations[memberIndex]) {
1157 // Check for offset.
1158 if (memberDecoration.first == spirv::Decoration::Offset) {
1159 // If offset info is empty, resize to the number of members;
1160 if (offsetInfo.empty()) {
1161 offsetInfo.resize(memberTypes.size());
1162 }
1163 offsetInfo[memberIndex] = memberDecoration.second[0];
1164 } else {
1165 if (!memberDecoration.second.empty()) {
1166 memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/1,
1167 memberDecoration.first,
1168 memberDecoration.second[0]);
1169 } else {
1170 memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/0,
1171 memberDecoration.first, 0);
1172 }
1173 }
1174 }
1175 }
1176 }
1177 }
1178
1179 uint32_t structID = operands[0];
1180 std::string structIdentifier = nameMap.lookup(Val: structID).str();
1181
1182 if (structIdentifier.empty()) {
1183 assert(unresolvedMemberTypes.empty() &&
1184 "didn't expect unresolved member types");
1185 typeMap[structID] =
1186 spirv::StructType::get(memberTypes, offsetInfo, memberDecorations: memberDecorationsInfo);
1187 } else {
1188 auto structTy = spirv::StructType::getIdentified(context, identifier: structIdentifier);
1189 typeMap[structID] = structTy;
1190
1191 if (!unresolvedMemberTypes.empty())
1192 deferredStructTypesInfos.push_back(Elt: {.deferredStructType: structTy, .unresolvedMemberTypes: unresolvedMemberTypes,
1193 .memberTypes: memberTypes, .offsetInfo: offsetInfo,
1194 .memberDecorationsInfo: memberDecorationsInfo});
1195 else if (failed(Result: structTy.trySetBody(memberTypes, offsetInfo,
1196 memberDecorations: memberDecorationsInfo)))
1197 return failure();
1198 }
1199
1200 // TODO: Update StructType to have member name as attribute as
1201 // well.
1202 return success();
1203}
1204
1205LogicalResult
1206spirv::Deserializer::processMatrixType(ArrayRef<uint32_t> operands) {
1207 if (operands.size() != 3) {
1208 // Three operands are needed: result_id, column_type, and column_count
1209 return emitError(loc: unknownLoc, message: "OpTypeMatrix must have 3 operands"
1210 " (result_id, column_type, and column_count)");
1211 }
1212 // Matrix columns must be of vector type
1213 Type elementTy = getType(id: operands[1]);
1214 if (!elementTy) {
1215 return emitError(loc: unknownLoc,
1216 message: "OpTypeMatrix references undefined column type.")
1217 << operands[1];
1218 }
1219
1220 uint32_t colsCount = operands[2];
1221 typeMap[operands[0]] = spirv::MatrixType::get(columnType: elementTy, columnCount: colsCount);
1222 return success();
1223}
1224
1225LogicalResult
1226spirv::Deserializer::processTypeForwardPointer(ArrayRef<uint32_t> operands) {
1227 if (operands.size() != 2)
1228 return emitError(loc: unknownLoc,
1229 message: "OpTypeForwardPointer instruction must have two operands");
1230
1231 typeForwardPointerIDs.insert(X: operands[0]);
1232 // TODO: Use the 2nd operand (Storage Class) to validate the OpTypePointer
1233 // instruction that defines the actual type.
1234
1235 return success();
1236}
1237
1238LogicalResult
1239spirv::Deserializer::processImageType(ArrayRef<uint32_t> operands) {
1240 // TODO: Add support for Access Qualifier.
1241 if (operands.size() != 8)
1242 return emitError(
1243 loc: unknownLoc,
1244 message: "OpTypeImage with non-eight operands are not supported yet");
1245
1246 Type elementTy = getType(id: operands[1]);
1247 if (!elementTy)
1248 return emitError(loc: unknownLoc, message: "OpTypeImage references undefined <id>: ")
1249 << operands[1];
1250
1251 auto dim = spirv::symbolizeDim(operands[2]);
1252 if (!dim)
1253 return emitError(loc: unknownLoc, message: "unknown Dim for OpTypeImage: ")
1254 << operands[2];
1255
1256 auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]);
1257 if (!depthInfo)
1258 return emitError(loc: unknownLoc, message: "unknown Depth for OpTypeImage: ")
1259 << operands[3];
1260
1261 auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]);
1262 if (!arrayedInfo)
1263 return emitError(loc: unknownLoc, message: "unknown Arrayed for OpTypeImage: ")
1264 << operands[4];
1265
1266 auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]);
1267 if (!samplingInfo)
1268 return emitError(loc: unknownLoc, message: "unknown MS for OpTypeImage: ") << operands[5];
1269
1270 auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]);
1271 if (!samplerUseInfo)
1272 return emitError(loc: unknownLoc, message: "unknown Sampled for OpTypeImage: ")
1273 << operands[6];
1274
1275 auto format = spirv::symbolizeImageFormat(operands[7]);
1276 if (!format)
1277 return emitError(loc: unknownLoc, message: "unknown Format for OpTypeImage: ")
1278 << operands[7];
1279
1280 typeMap[operands[0]] = spirv::ImageType::get(
1281 elementTy, dim.value(), depthInfo.value(), arrayedInfo.value(),
1282 samplingInfo.value(), samplerUseInfo.value(), format.value());
1283 return success();
1284}
1285
1286LogicalResult
1287spirv::Deserializer::processSampledImageType(ArrayRef<uint32_t> operands) {
1288 if (operands.size() != 2)
1289 return emitError(loc: unknownLoc, message: "OpTypeSampledImage must have two operands");
1290
1291 Type elementTy = getType(id: operands[1]);
1292 if (!elementTy)
1293 return emitError(loc: unknownLoc,
1294 message: "OpTypeSampledImage references undefined <id>: ")
1295 << operands[1];
1296
1297 typeMap[operands[0]] = spirv::SampledImageType::get(imageType: elementTy);
1298 return success();
1299}
1300
1301//===----------------------------------------------------------------------===//
1302// Constant
1303//===----------------------------------------------------------------------===//
1304
1305LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands,
1306 bool isSpec) {
1307 StringRef opname = isSpec ? "OpSpecConstant" : "OpConstant";
1308
1309 if (operands.size() < 2) {
1310 return emitError(loc: unknownLoc)
1311 << opname << " must have type <id> and result <id>";
1312 }
1313 if (operands.size() < 3) {
1314 return emitError(loc: unknownLoc)
1315 << opname << " must have at least 1 more parameter";
1316 }
1317
1318 Type resultType = getType(id: operands[0]);
1319 if (!resultType) {
1320 return emitError(loc: unknownLoc, message: "undefined result type from <id> ")
1321 << operands[0];
1322 }
1323
1324 auto checkOperandSizeForBitwidth = [&](unsigned bitwidth) -> LogicalResult {
1325 if (bitwidth == 64) {
1326 if (operands.size() == 4) {
1327 return success();
1328 }
1329 return emitError(loc: unknownLoc)
1330 << opname << " should have 2 parameters for 64-bit values";
1331 }
1332 if (bitwidth <= 32) {
1333 if (operands.size() == 3) {
1334 return success();
1335 }
1336
1337 return emitError(loc: unknownLoc)
1338 << opname
1339 << " should have 1 parameter for values with no more than 32 bits";
1340 }
1341 return emitError(loc: unknownLoc, message: "unsupported OpConstant bitwidth: ")
1342 << bitwidth;
1343 };
1344
1345 auto resultID = operands[1];
1346
1347 if (auto intType = dyn_cast<IntegerType>(resultType)) {
1348 auto bitwidth = intType.getWidth();
1349 if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1350 return failure();
1351 }
1352
1353 APInt value;
1354 if (bitwidth == 64) {
1355 // 64-bit integers are represented with two SPIR-V words. According to
1356 // SPIR-V spec: "When the type’s bit width is larger than one word, the
1357 // literal’s low-order words appear first."
1358 struct DoubleWord {
1359 uint32_t word1;
1360 uint32_t word2;
1361 } words = {.word1: operands[2], .word2: operands[3]};
1362 value = APInt(64, llvm::bit_cast<uint64_t>(from: words), /*isSigned=*/true);
1363 } else if (bitwidth <= 32) {
1364 value = APInt(bitwidth, operands[2], /*isSigned=*/true,
1365 /*implicitTrunc=*/true);
1366 }
1367
1368 auto attr = opBuilder.getIntegerAttr(intType, value);
1369
1370 if (isSpec) {
1371 createSpecConstant(unknownLoc, resultID, attr);
1372 } else {
1373 // For normal constants, we just record the attribute (and its type) for
1374 // later materialization at use sites.
1375 constantMap.try_emplace(resultID, attr, intType);
1376 }
1377
1378 return success();
1379 }
1380
1381 if (auto floatType = dyn_cast<FloatType>(resultType)) {
1382 auto bitwidth = floatType.getWidth();
1383 if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1384 return failure();
1385 }
1386
1387 APFloat value(0.f);
1388 if (floatType.isF64()) {
1389 // Double values are represented with two SPIR-V words. According to
1390 // SPIR-V spec: "When the type’s bit width is larger than one word, the
1391 // literal’s low-order words appear first."
1392 struct DoubleWord {
1393 uint32_t word1;
1394 uint32_t word2;
1395 } words = {.word1: operands[2], .word2: operands[3]};
1396 value = APFloat(llvm::bit_cast<double>(from: words));
1397 } else if (floatType.isF32()) {
1398 value = APFloat(llvm::bit_cast<float>(from: operands[2]));
1399 } else if (floatType.isF16()) {
1400 APInt data(16, operands[2]);
1401 value = APFloat(APFloat::IEEEhalf(), data);
1402 }
1403
1404 auto attr = opBuilder.getFloatAttr(floatType, value);
1405 if (isSpec) {
1406 createSpecConstant(unknownLoc, resultID, attr);
1407 } else {
1408 // For normal constants, we just record the attribute (and its type) for
1409 // later materialization at use sites.
1410 constantMap.try_emplace(resultID, attr, floatType);
1411 }
1412
1413 return success();
1414 }
1415
1416 return emitError(loc: unknownLoc, message: "OpConstant can only generate values of "
1417 "scalar integer or floating-point type");
1418}
1419
1420LogicalResult spirv::Deserializer::processConstantBool(
1421 bool isTrue, ArrayRef<uint32_t> operands, bool isSpec) {
1422 if (operands.size() != 2) {
1423 return emitError(loc: unknownLoc, message: "Op")
1424 << (isSpec ? "Spec" : "") << "Constant"
1425 << (isTrue ? "True" : "False")
1426 << " must have type <id> and result <id>";
1427 }
1428
1429 auto attr = opBuilder.getBoolAttr(value: isTrue);
1430 auto resultID = operands[1];
1431 if (isSpec) {
1432 createSpecConstant(unknownLoc, resultID, attr);
1433 } else {
1434 // For normal constants, we just record the attribute (and its type) for
1435 // later materialization at use sites.
1436 constantMap.try_emplace(Key: resultID, Args&: attr, Args: opBuilder.getI1Type());
1437 }
1438
1439 return success();
1440}
1441
1442LogicalResult
1443spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
1444 if (operands.size() < 2) {
1445 return emitError(loc: unknownLoc,
1446 message: "OpConstantComposite must have type <id> and result <id>");
1447 }
1448 if (operands.size() < 3) {
1449 return emitError(loc: unknownLoc,
1450 message: "OpConstantComposite must have at least 1 parameter");
1451 }
1452
1453 Type resultType = getType(id: operands[0]);
1454 if (!resultType) {
1455 return emitError(loc: unknownLoc, message: "undefined result type from <id> ")
1456 << operands[0];
1457 }
1458
1459 SmallVector<Attribute, 4> elements;
1460 elements.reserve(N: operands.size() - 2);
1461 for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1462 auto elementInfo = getConstant(id: operands[i]);
1463 if (!elementInfo) {
1464 return emitError(loc: unknownLoc, message: "OpConstantComposite component <id> ")
1465 << operands[i] << " must come from a normal constant";
1466 }
1467 elements.push_back(Elt: elementInfo->first);
1468 }
1469
1470 auto resultID = operands[1];
1471 if (auto shapedType = dyn_cast<ShapedType>(resultType)) {
1472 auto attr = DenseElementsAttr::get(shapedType, elements);
1473 // For normal constants, we just record the attribute (and its type) for
1474 // later materialization at use sites.
1475 constantMap.try_emplace(resultID, attr, shapedType);
1476 } else if (auto arrayType = dyn_cast<spirv::ArrayType>(Val&: resultType)) {
1477 auto attr = opBuilder.getArrayAttr(elements);
1478 constantMap.try_emplace(resultID, attr, resultType);
1479 } else {
1480 return emitError(loc: unknownLoc, message: "unsupported OpConstantComposite type: ")
1481 << resultType;
1482 }
1483
1484 return success();
1485}
1486
1487LogicalResult
1488spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
1489 if (operands.size() < 2) {
1490 return emitError(loc: unknownLoc,
1491 message: "OpConstantComposite must have type <id> and result <id>");
1492 }
1493 if (operands.size() < 3) {
1494 return emitError(loc: unknownLoc,
1495 message: "OpConstantComposite must have at least 1 parameter");
1496 }
1497
1498 Type resultType = getType(id: operands[0]);
1499 if (!resultType) {
1500 return emitError(loc: unknownLoc, message: "undefined result type from <id> ")
1501 << operands[0];
1502 }
1503
1504 auto resultID = operands[1];
1505 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(id: resultID));
1506
1507 SmallVector<Attribute, 4> elements;
1508 elements.reserve(N: operands.size() - 2);
1509 for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1510 auto elementInfo = getSpecConstant(operands[i]);
1511 elements.push_back(SymbolRefAttr::get(elementInfo));
1512 }
1513
1514 auto op = opBuilder.create<spirv::SpecConstantCompositeOp>(
1515 unknownLoc, TypeAttr::get(resultType), symName,
1516 opBuilder.getArrayAttr(elements));
1517 specConstCompositeMap[resultID] = op;
1518
1519 return success();
1520}
1521
1522LogicalResult
1523spirv::Deserializer::processSpecConstantOperation(ArrayRef<uint32_t> operands) {
1524 if (operands.size() < 3)
1525 return emitError(loc: unknownLoc, message: "OpConstantOperation must have type <id>, "
1526 "result <id>, and operand opcode");
1527
1528 uint32_t resultTypeID = operands[0];
1529
1530 if (!getType(id: resultTypeID))
1531 return emitError(loc: unknownLoc, message: "undefined result type from <id> ")
1532 << resultTypeID;
1533
1534 uint32_t resultID = operands[1];
1535 spirv::Opcode enclosedOpcode = static_cast<spirv::Opcode>(operands[2]);
1536 auto emplaceResult = specConstOperationMap.try_emplace(
1537 Key: resultID,
1538 Args: SpecConstOperationMaterializationInfo{
1539 enclosedOpcode, resultTypeID,
1540 SmallVector<uint32_t>{operands.begin() + 3, operands.end()}});
1541
1542 if (!emplaceResult.second)
1543 return emitError(loc: unknownLoc, message: "value with <id>: ")
1544 << resultID << " is probably defined before.";
1545
1546 return success();
1547}
1548
1549Value spirv::Deserializer::materializeSpecConstantOperation(
1550 uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID,
1551 ArrayRef<uint32_t> enclosedOpOperands) {
1552
1553 Type resultType = getType(id: resultTypeID);
1554
1555 // Instructions wrapped by OpSpecConstantOp need an ID for their
1556 // Deserializer::processOp<op_name>(...) to emit the corresponding SPIR-V
1557 // dialect wrapped op. For that purpose, a new value map is created and "fake"
1558 // ID in that map is assigned to the result of the enclosed instruction. Note
1559 // that there is no need to update this fake ID since we only need to
1560 // reference the created Value for the enclosed op from the spv::YieldOp
1561 // created later in this method (both of which are the only values in their
1562 // region: the SpecConstantOperation's region). If we encounter another
1563 // SpecConstantOperation in the module, we simply re-use the fake ID since the
1564 // previous Value assigned to it isn't visible in the current scope anyway.
1565 DenseMap<uint32_t, Value> newValueMap;
1566 llvm::SaveAndRestore valueMapGuard(valueMap, newValueMap);
1567 constexpr uint32_t fakeID = static_cast<uint32_t>(-3);
1568
1569 SmallVector<uint32_t, 4> enclosedOpResultTypeAndOperands;
1570 enclosedOpResultTypeAndOperands.push_back(Elt: resultTypeID);
1571 enclosedOpResultTypeAndOperands.push_back(Elt: fakeID);
1572 enclosedOpResultTypeAndOperands.append(in_start: enclosedOpOperands.begin(),
1573 in_end: enclosedOpOperands.end());
1574
1575 // Process enclosed instruction before creating the enclosing
1576 // specConstantOperation (and its region). This way, references to constants,
1577 // global variables, and spec constants will be materialized outside the new
1578 // op's region. For more info, see Deserializer::getValue's implementation.
1579 if (failed(
1580 processInstruction(enclosedOpcode, enclosedOpResultTypeAndOperands)))
1581 return Value();
1582
1583 // Since the enclosed op is emitted in the current block, split it in a
1584 // separate new block.
1585 Block *enclosedBlock = curBlock->splitBlock(splitBeforeOp: &curBlock->back());
1586
1587 auto loc = createFileLineColLoc(opBuilder);
1588 auto specConstOperationOp =
1589 opBuilder.create<spirv::SpecConstantOperationOp>(loc, resultType);
1590
1591 Region &body = specConstOperationOp.getBody();
1592 // Move the new block into SpecConstantOperation's body.
1593 body.getBlocks().splice(where: body.end(), L2&: curBlock->getParent()->getBlocks(),
1594 first: Region::iterator(enclosedBlock));
1595 Block &block = body.back();
1596
1597 // RAII guard to reset the insertion point to the module's region after
1598 // deserializing the body of the specConstantOperation.
1599 OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
1600 opBuilder.setInsertionPointToEnd(&block);
1601
1602 opBuilder.create<spirv::YieldOp>(loc, block.front().getResult(0));
1603 return specConstOperationOp.getResult();
1604}
1605
1606LogicalResult
1607spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
1608 if (operands.size() != 2) {
1609 return emitError(loc: unknownLoc,
1610 message: "OpConstantNull must have type <id> and result <id>");
1611 }
1612
1613 Type resultType = getType(id: operands[0]);
1614 if (!resultType) {
1615 return emitError(loc: unknownLoc, message: "undefined result type from <id> ")
1616 << operands[0];
1617 }
1618
1619 auto resultID = operands[1];
1620 if (resultType.isIntOrFloat() || isa<VectorType>(Val: resultType)) {
1621 auto attr = opBuilder.getZeroAttr(resultType);
1622 // For normal constants, we just record the attribute (and its type) for
1623 // later materialization at use sites.
1624 constantMap.try_emplace(resultID, attr, resultType);
1625 return success();
1626 }
1627
1628 return emitError(loc: unknownLoc, message: "unsupported OpConstantNull type: ")
1629 << resultType;
1630}
1631
1632//===----------------------------------------------------------------------===//
1633// Control flow
1634//===----------------------------------------------------------------------===//
1635
1636Block *spirv::Deserializer::getOrCreateBlock(uint32_t id) {
1637 if (auto *block = getBlock(id)) {
1638 LLVM_DEBUG(logger.startLine() << "[block] got exiting block for id = " << id
1639 << " @ " << block << "\n");
1640 return block;
1641 }
1642
1643 // We don't know where this block will be placed finally (in a
1644 // spirv.mlir.selection or spirv.mlir.loop or function). Create it into the
1645 // function for now and sort out the proper place later.
1646 auto *block = curFunction->addBlock();
1647 LLVM_DEBUG(logger.startLine() << "[block] created block for id = " << id
1648 << " @ " << block << "\n");
1649 return blockMap[id] = block;
1650}
1651
1652LogicalResult spirv::Deserializer::processBranch(ArrayRef<uint32_t> operands) {
1653 if (!curBlock) {
1654 return emitError(loc: unknownLoc, message: "OpBranch must appear inside a block");
1655 }
1656
1657 if (operands.size() != 1) {
1658 return emitError(loc: unknownLoc, message: "OpBranch must take exactly one target label");
1659 }
1660
1661 auto *target = getOrCreateBlock(id: operands[0]);
1662 auto loc = createFileLineColLoc(opBuilder);
1663 // The preceding instruction for the OpBranch instruction could be an
1664 // OpLoopMerge or an OpSelectionMerge instruction, in this case they will have
1665 // the same OpLine information.
1666 opBuilder.create<spirv::BranchOp>(loc, target);
1667
1668 clearDebugLine();
1669 return success();
1670}
1671
1672LogicalResult
1673spirv::Deserializer::processBranchConditional(ArrayRef<uint32_t> operands) {
1674 if (!curBlock) {
1675 return emitError(loc: unknownLoc,
1676 message: "OpBranchConditional must appear inside a block");
1677 }
1678
1679 if (operands.size() != 3 && operands.size() != 5) {
1680 return emitError(loc: unknownLoc,
1681 message: "OpBranchConditional must have condition, true label, "
1682 "false label, and optionally two branch weights");
1683 }
1684
1685 auto condition = getValue(id: operands[0]);
1686 auto *trueBlock = getOrCreateBlock(id: operands[1]);
1687 auto *falseBlock = getOrCreateBlock(id: operands[2]);
1688
1689 std::optional<std::pair<uint32_t, uint32_t>> weights;
1690 if (operands.size() == 5) {
1691 weights = std::make_pair(x: operands[3], y: operands[4]);
1692 }
1693 // The preceding instruction for the OpBranchConditional instruction could be
1694 // an OpSelectionMerge instruction, in this case they will have the same
1695 // OpLine information.
1696 auto loc = createFileLineColLoc(opBuilder);
1697 opBuilder.create<spirv::BranchConditionalOp>(
1698 loc, condition, trueBlock,
1699 /*trueArguments=*/ArrayRef<Value>(), falseBlock,
1700 /*falseArguments=*/ArrayRef<Value>(), weights);
1701
1702 clearDebugLine();
1703 return success();
1704}
1705
1706LogicalResult spirv::Deserializer::processLabel(ArrayRef<uint32_t> operands) {
1707 if (!curFunction) {
1708 return emitError(loc: unknownLoc, message: "OpLabel must appear inside a function");
1709 }
1710
1711 if (operands.size() != 1) {
1712 return emitError(loc: unknownLoc, message: "OpLabel should only have result <id>");
1713 }
1714
1715 auto labelID = operands[0];
1716 // We may have forward declared this block.
1717 auto *block = getOrCreateBlock(id: labelID);
1718 LLVM_DEBUG(logger.startLine()
1719 << "[block] populating block " << block << "\n");
1720 // If we have seen this block, make sure it was just a forward declaration.
1721 assert(block->empty() && "re-deserialize the same block!");
1722
1723 opBuilder.setInsertionPointToStart(block);
1724 blockMap[labelID] = curBlock = block;
1725
1726 return success();
1727}
1728
1729LogicalResult
1730spirv::Deserializer::processSelectionMerge(ArrayRef<uint32_t> operands) {
1731 if (!curBlock) {
1732 return emitError(loc: unknownLoc, message: "OpSelectionMerge must appear in a block");
1733 }
1734
1735 if (operands.size() < 2) {
1736 return emitError(
1737 loc: unknownLoc,
1738 message: "OpSelectionMerge must specify merge target and selection control");
1739 }
1740
1741 auto *mergeBlock = getOrCreateBlock(id: operands[0]);
1742 auto loc = createFileLineColLoc(opBuilder);
1743 auto selectionControl = operands[1];
1744
1745 if (!blockMergeInfo.try_emplace(Key: curBlock, Args&: loc, Args&: selectionControl, Args&: mergeBlock)
1746 .second) {
1747 return emitError(
1748 loc: unknownLoc,
1749 message: "a block cannot have more than one OpSelectionMerge instruction");
1750 }
1751
1752 return success();
1753}
1754
1755LogicalResult
1756spirv::Deserializer::processLoopMerge(ArrayRef<uint32_t> operands) {
1757 if (!curBlock) {
1758 return emitError(loc: unknownLoc, message: "OpLoopMerge must appear in a block");
1759 }
1760
1761 if (operands.size() < 3) {
1762 return emitError(loc: unknownLoc, message: "OpLoopMerge must specify merge target, "
1763 "continue target and loop control");
1764 }
1765
1766 auto *mergeBlock = getOrCreateBlock(id: operands[0]);
1767 auto *continueBlock = getOrCreateBlock(id: operands[1]);
1768 auto loc = createFileLineColLoc(opBuilder);
1769 uint32_t loopControl = operands[2];
1770
1771 if (!blockMergeInfo
1772 .try_emplace(Key: curBlock, Args&: loc, Args&: loopControl, Args&: mergeBlock, Args&: continueBlock)
1773 .second) {
1774 return emitError(
1775 loc: unknownLoc,
1776 message: "a block cannot have more than one OpLoopMerge instruction");
1777 }
1778
1779 return success();
1780}
1781
1782LogicalResult spirv::Deserializer::processPhi(ArrayRef<uint32_t> operands) {
1783 if (!curBlock) {
1784 return emitError(loc: unknownLoc, message: "OpPhi must appear in a block");
1785 }
1786
1787 if (operands.size() < 4) {
1788 return emitError(loc: unknownLoc, message: "OpPhi must specify result type, result <id>, "
1789 "and variable-parent pairs");
1790 }
1791
1792 // Create a block argument for this OpPhi instruction.
1793 Type blockArgType = getType(id: operands[0]);
1794 BlockArgument blockArg = curBlock->addArgument(type: blockArgType, loc: unknownLoc);
1795 valueMap[operands[1]] = blockArg;
1796 LLVM_DEBUG(logger.startLine()
1797 << "[phi] created block argument " << blockArg
1798 << " id = " << operands[1] << " of type " << blockArgType << "\n");
1799
1800 // For each (value, predecessor) pair, insert the value to the predecessor's
1801 // blockPhiInfo entry so later we can fix the block argument there.
1802 for (unsigned i = 2, e = operands.size(); i < e; i += 2) {
1803 uint32_t value = operands[i];
1804 Block *predecessor = getOrCreateBlock(id: operands[i + 1]);
1805 std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock};
1806 blockPhiInfo[predecessorTargetPair].push_back(Elt: value);
1807 LLVM_DEBUG(logger.startLine() << "[phi] predecessor @ " << predecessor
1808 << " with arg id = " << value << "\n");
1809 }
1810
1811 return success();
1812}
1813
namespace {
/// A class for putting all blocks in a structured selection/loop in a
/// spirv.mlir.selection/spirv.mlir.loop op.
///
/// One instance handles exactly one construct, described by its header block,
/// its merge block and (for loops only) its continue block.
class ControlFlowStructurizer {
public:
#ifndef NDEBUG
  /// Debug-build constructor: identical to the release one, plus the `logger`
  /// used for LLVM_DEBUG output. `cont` is nullptr for a selection construct.
  ControlFlowStructurizer(Location loc, uint32_t control,
                          spirv::BlockMergeInfoMap &mergeInfo, Block *header,
                          Block *merge, Block *cont,
                          llvm::ScopedPrinter &logger)
      : location(loc), control(control), blockMergeInfo(mergeInfo),
        headerBlock(header), mergeBlock(merge), continueBlock(cont),
        logger(logger) {}
#else
  /// Release-build constructor. `cont` is nullptr for a selection construct.
  ControlFlowStructurizer(Location loc, uint32_t control,
                          spirv::BlockMergeInfoMap &mergeInfo, Block *header,
                          Block *merge, Block *cont)
      : location(loc), control(control), blockMergeInfo(mergeInfo),
        headerBlock(header), mergeBlock(merge), continueBlock(cont) {}
#endif

  /// Structurizes the selection or loop whose header is `headerBlock`.
  ///
  /// This method will create a spirv.mlir.selection op (when `continueBlock`
  /// is null) or a spirv.mlir.loop op (otherwise) in the `mergeBlock` and
  /// clone all blocks of the construct into that op's region. All branches to
  /// the `headerBlock` will be redirected to the `mergeBlock`. This method will
  /// also update `mergeInfo` by remapping all blocks inside to the newly
  /// cloned ones inside structured control flow op's regions.
  LogicalResult structurize();

private:
  /// Creates a new spirv.mlir.selection op at the beginning of the
  /// `mergeBlock`.
  spirv::SelectionOp createSelectionOp(uint32_t selectionControl);

  /// Creates a new spirv.mlir.loop op at the beginning of the `mergeBlock`.
  spirv::LoopOp createLoopOp(uint32_t loopControl);

  /// Collects all blocks reachable from `headerBlock` except `mergeBlock`
  /// into `constructBlocks`.
  void collectBlocksInConstruct();

  /// Location given to the ops created for this construct.
  Location location;
  /// Raw selection/loop control word; cast to the matching enum when the
  /// structured op is created.
  uint32_t control;

  /// Merge information shared with the deserializer; updated in place when
  /// nested constructs are remapped.
  spirv::BlockMergeInfoMap &blockMergeInfo;

  Block *headerBlock;
  Block *mergeBlock;
  Block *continueBlock; // nullptr for spirv.mlir.selection

  /// Blocks belonging to this construct, in discovery order starting with
  /// `headerBlock`.
  SetVector<Block *> constructBlocks;

#ifndef NDEBUG
  /// A logger used to emit information during the deserialization process.
  llvm::ScopedPrinter &logger;
#endif
};
} // namespace
1872
1873spirv::SelectionOp
1874ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
1875 // Create a builder and set the insertion point to the beginning of the
1876 // merge block so that the newly created SelectionOp will be inserted there.
1877 OpBuilder builder(&mergeBlock->front());
1878
1879 auto control = static_cast<spirv::SelectionControl>(selectionControl);
1880 auto selectionOp = builder.create<spirv::SelectionOp>(location, control);
1881 selectionOp.addMergeBlock(builder);
1882
1883 return selectionOp;
1884}
1885
1886spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {
1887 // Create a builder and set the insertion point to the beginning of the
1888 // merge block so that the newly created LoopOp will be inserted there.
1889 OpBuilder builder(&mergeBlock->front());
1890
1891 auto control = static_cast<spirv::LoopControl>(loopControl);
1892 auto loopOp = builder.create<spirv::LoopOp>(location, control);
1893 loopOp.addEntryAndMergeBlock(builder);
1894
1895 return loopOp;
1896}
1897
1898void ControlFlowStructurizer::collectBlocksInConstruct() {
1899 assert(constructBlocks.empty() && "expected empty constructBlocks");
1900
1901 // Put the header block in the work list first.
1902 constructBlocks.insert(X: headerBlock);
1903
1904 // For each item in the work list, add its successors excluding the merge
1905 // block.
1906 for (unsigned i = 0; i < constructBlocks.size(); ++i) {
1907 for (auto *successor : constructBlocks[i]->getSuccessors())
1908 if (successor != mergeBlock)
1909 constructBlocks.insert(X: successor);
1910 }
1911}
1912
1913LogicalResult ControlFlowStructurizer::structurize() {
1914 Operation *op = nullptr;
1915 bool isLoop = continueBlock != nullptr;
1916 if (isLoop) {
1917 if (auto loopOp = createLoopOp(control))
1918 op = loopOp.getOperation();
1919 } else {
1920 if (auto selectionOp = createSelectionOp(control))
1921 op = selectionOp.getOperation();
1922 }
1923 if (!op)
1924 return failure();
1925 Region &body = op->getRegion(index: 0);
1926
1927 IRMapping mapper;
1928 // All references to the old merge block should be directed to the
1929 // selection/loop merge block in the SelectionOp/LoopOp's region.
1930 mapper.map(from: mergeBlock, to: &body.back());
1931
1932 collectBlocksInConstruct();
1933
1934 // We've identified all blocks belonging to the selection/loop's region. Now
1935 // need to "move" them into the selection/loop. Instead of really moving the
1936 // blocks, in the following we copy them and remap all values and branches.
1937 // This is because:
1938 // * Inserting a block into a region requires the block not in any region
1939 // before. But selections/loops can nest so we can create selection/loop ops
1940 // in a nested manner, which means some blocks may already be in a
1941 // selection/loop region when to be moved again.
1942 // * It's much trickier to fix up the branches into and out of the loop's
1943 // region: we need to treat not-moved blocks and moved blocks differently:
1944 // Not-moved blocks jumping to the loop header block need to jump to the
1945 // merge point containing the new loop op but not the loop continue block's
1946 // back edge. Moved blocks jumping out of the loop need to jump to the
1947 // merge block inside the loop region but not other not-moved blocks.
1948 // We cannot use replaceAllUsesWith clearly and it's harder to follow the
1949 // logic.
1950
1951 // Create a corresponding block in the SelectionOp/LoopOp's region for each
1952 // block in this loop construct.
1953 OpBuilder builder(body);
1954 for (auto *block : constructBlocks) {
1955 // Create a block and insert it before the selection/loop merge block in the
1956 // SelectionOp/LoopOp's region.
1957 auto *newBlock = builder.createBlock(insertBefore: &body.back());
1958 mapper.map(from: block, to: newBlock);
1959 LLVM_DEBUG(logger.startLine() << "[cf] cloned block " << newBlock
1960 << " from block " << block << "\n");
1961 if (!isFnEntryBlock(block)) {
1962 for (BlockArgument blockArg : block->getArguments()) {
1963 auto newArg =
1964 newBlock->addArgument(type: blockArg.getType(), loc: blockArg.getLoc());
1965 mapper.map(from: blockArg, to: newArg);
1966 LLVM_DEBUG(logger.startLine() << "[cf] remapped block argument "
1967 << blockArg << " to " << newArg << "\n");
1968 }
1969 } else {
1970 LLVM_DEBUG(logger.startLine()
1971 << "[cf] block " << block << " is a function entry block\n");
1972 }
1973
1974 for (auto &op : *block)
1975 newBlock->push_back(op: op.clone(mapper));
1976 }
1977
1978 // Go through all ops and remap the operands.
1979 auto remapOperands = [&](Operation *op) {
1980 for (auto &operand : op->getOpOperands())
1981 if (Value mappedOp = mapper.lookupOrNull(from: operand.get()))
1982 operand.set(mappedOp);
1983 for (auto &succOp : op->getBlockOperands())
1984 if (Block *mappedOp = mapper.lookupOrNull(from: succOp.get()))
1985 succOp.set(mappedOp);
1986 };
1987 for (auto &block : body)
1988 block.walk(callback&: remapOperands);
1989
1990 // We have created the SelectionOp/LoopOp and "moved" all blocks belonging to
1991 // the selection/loop construct into its region. Next we need to fix the
1992 // connections between this new SelectionOp/LoopOp with existing blocks.
1993
1994 // All existing incoming branches should go to the merge block, where the
1995 // SelectionOp/LoopOp resides right now.
1996 headerBlock->replaceAllUsesWith(newValue&: mergeBlock);
1997
1998 LLVM_DEBUG({
1999 logger.startLine() << "[cf] after cloning and fixing references:\n";
2000 headerBlock->getParentOp()->print(logger.getOStream());
2001 logger.startLine() << "\n";
2002 });
2003
2004 if (isLoop) {
2005 if (!mergeBlock->args_empty()) {
2006 return mergeBlock->getParentOp()->emitError(
2007 message: "OpPhi in loop merge block unsupported");
2008 }
2009
2010 // The loop header block may have block arguments. Since now we place the
2011 // loop op inside the old merge block, we need to make sure the old merge
2012 // block has the same block argument list.
2013 for (BlockArgument blockArg : headerBlock->getArguments())
2014 mergeBlock->addArgument(type: blockArg.getType(), loc: blockArg.getLoc());
2015
2016 // If the loop header block has block arguments, make sure the spirv.Branch
2017 // op matches.
2018 SmallVector<Value, 4> blockArgs;
2019 if (!headerBlock->args_empty())
2020 blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()};
2021
2022 // The loop entry block should have a unconditional branch jumping to the
2023 // loop header block.
2024 builder.setInsertionPointToEnd(&body.front());
2025 builder.create<spirv::BranchOp>(location, mapper.lookupOrNull(headerBlock),
2026 ArrayRef<Value>(blockArgs));
2027 }
2028
2029 // Values defined inside the selection region that need to be yielded outside
2030 // the region.
2031 SmallVector<Value> valuesToYield;
2032 // Outside uses of values that were sunk into the selection region. Those uses
2033 // will be replaced with values returned by the SelectionOp.
2034 SmallVector<Value> outsideUses;
2035
2036 // Move block arguments of the original block (`mergeBlock`) into the merge
2037 // block inside the selection (`body.back()`). Values produced by block
2038 // arguments will be yielded by the selection region. We do not update uses or
2039 // erase original block arguments yet. It will be done later in the code.
2040 //
2041 // Code below is not executed for loops as it would interfere with the logic
2042 // above. Currently block arguments in the merge block are not supported, but
2043 // instead, the code above copies those arguments from the header block into
2044 // the merge block. As such, running the code would yield those copied
2045 // arguments that is most likely not a desired behaviour. This may need to be
2046 // revisited in the future.
2047 if (!isLoop)
2048 for (BlockArgument blockArg : mergeBlock->getArguments()) {
2049 // Create new block arguments in the last block ("merge block") of the
2050 // selection region. We create one argument for each argument in
2051 // `mergeBlock`. This new value will need to be yielded, and the original
2052 // value replaced, so add them to appropriate vectors.
2053 body.back().addArgument(type: blockArg.getType(), loc: blockArg.getLoc());
2054 valuesToYield.push_back(Elt: body.back().getArguments().back());
2055 outsideUses.push_back(Elt: blockArg);
2056 }
2057
2058 // All the blocks cloned into the SelectionOp/LoopOp's region can now be
2059 // cleaned up.
2060 LLVM_DEBUG(logger.startLine() << "[cf] cleaning up blocks after clone\n");
2061 // First we need to drop all operands' references inside all blocks. This is
2062 // needed because we can have blocks referencing SSA values from one another.
2063 for (auto *block : constructBlocks)
2064 block->dropAllReferences();
2065
2066 // All internal uses should be removed from original blocks by now, so
2067 // whatever is left is an outside use and will need to be yielded from
2068 // the newly created selection / loop region.
2069 for (Block *block : constructBlocks) {
2070 for (Operation &op : *block) {
2071 if (!op.use_empty())
2072 for (Value result : op.getResults()) {
2073 valuesToYield.push_back(Elt: mapper.lookupOrNull(from: result));
2074 outsideUses.push_back(Elt: result);
2075 }
2076 }
2077 for (BlockArgument &arg : block->getArguments()) {
2078 if (!arg.use_empty()) {
2079 valuesToYield.push_back(Elt: mapper.lookupOrNull(from: arg));
2080 outsideUses.push_back(Elt: arg);
2081 }
2082 }
2083 }
2084
2085 assert(valuesToYield.size() == outsideUses.size());
2086
2087 // If we need to yield any values from the selection / loop region we will
2088 // take care of it here.
2089 if (!valuesToYield.empty()) {
2090 LLVM_DEBUG(logger.startLine()
2091 << "[cf] yielding values from the selection / loop region\n");
2092
2093 // Update `mlir.merge` with values to be yield.
2094 auto mergeOps = body.back().getOps<spirv::MergeOp>();
2095 Operation *merge = llvm::getSingleElement(mergeOps);
2096 assert(merge);
2097 merge->setOperands(valuesToYield);
2098
2099 // MLIR does not allow changing the number of results of an operation, so
2100 // we create a new SelectionOp / LoopOp with required list of results and
2101 // move the region from the initial SelectionOp / LoopOp. The initial
2102 // operation is then removed. Since we move the region to the new op all
2103 // links between blocks and remapping we have previously done should be
2104 // preserved.
2105 builder.setInsertionPoint(&mergeBlock->front());
2106
2107 Operation *newOp = nullptr;
2108
2109 if (isLoop)
2110 newOp = builder.create<spirv::LoopOp>(
2111 location, TypeRange(ValueRange(outsideUses)),
2112 static_cast<spirv::LoopControl>(control));
2113 else
2114 newOp = builder.create<spirv::SelectionOp>(
2115 location, TypeRange(ValueRange(outsideUses)),
2116 static_cast<spirv::SelectionControl>(control));
2117
2118 newOp->getRegion(index: 0).takeBody(other&: body);
2119
2120 // Remove initial op and swap the pointer to the newly created one.
2121 op->erase();
2122 op = newOp;
2123
2124 // Update all outside uses to use results of the SelectionOp / LoopOp and
2125 // remove block arguments from the original merge block.
2126 for (unsigned i = 0, e = outsideUses.size(); i != e; ++i)
2127 outsideUses[i].replaceAllUsesWith(newValue: op->getResult(idx: i));
2128
2129 // We do not support block arguments in loop merge block. Also running this
2130 // function with loop would break some of the loop specific code above
2131 // dealing with block arguments.
2132 if (!isLoop)
2133 mergeBlock->eraseArguments(start: 0, num: mergeBlock->getNumArguments());
2134 }
2135
2136 // Check that whether some op in the to-be-erased blocks still has uses. Those
2137 // uses come from blocks that won't be sinked into the SelectionOp/LoopOp's
2138 // region. We cannot handle such cases given that once a value is sinked into
2139 // the SelectionOp/LoopOp's region, there is no escape for it.
2140 for (auto *block : constructBlocks) {
2141 for (Operation &op : *block)
2142 if (!op.use_empty())
2143 return op.emitOpError(message: "failed control flow structurization: value has "
2144 "uses outside of the "
2145 "enclosing selection/loop construct");
2146 for (BlockArgument &arg : block->getArguments())
2147 if (!arg.use_empty())
2148 return emitError(loc: arg.getLoc(), message: "failed control flow structurization: "
2149 "block argument has uses outside of the "
2150 "enclosing selection/loop construct");
2151 }
2152
2153 // Then erase all old blocks.
2154 for (auto *block : constructBlocks) {
2155 // We've cloned all blocks belonging to this construct into the structured
2156 // control flow op's region. Among these blocks, some may compose another
2157 // selection/loop. If so, they will be recorded within blockMergeInfo.
2158 // We need to update the pointers there to the newly remapped ones so we can
2159 // continue structurizing them later.
2160 //
2161 // We need to walk each block as constructBlocks do not include blocks
2162 // internal to ops already structured within those blocks. It is not
2163 // fully clear to me why the mergeInfo of blocks (yet to be structured)
2164 // inside already structured selections/loops get invalidated and needs
2165 // updating, however the following example code can cause a crash (depending
2166 // on the structuring order), when the most inner selection is being
2167 // structured after the outer selection and loop have been already
2168 // structured:
2169 //
2170 // spirv.mlir.for {
2171 // // ...
2172 // spirv.mlir.selection {
2173 // // ..
2174 // // A selection region that hasn't been yet structured!
2175 // // ..
2176 // }
2177 // // ...
2178 // }
2179 //
2180 // If the loop gets structured after the outer selection, but before the
2181 // inner selection. Moving the already structured selection inside the loop
2182 // will invalidate the mergeInfo of the region that is not yet structured.
2183 // Just going over constructBlocks will not check and updated header blocks
2184 // inside the already structured selection region. Walking block fixes that.
2185 //
2186 // TODO: If structuring was done in a fixed order starting with inner
2187 // most constructs this most likely not be an issue and the whole code
2188 // section could be removed. However, with the current non-deterministic
2189 // order this is not possible.
2190 //
2191 // TODO: The asserts in the following assumes input SPIR-V blob forms
2192 // correctly nested selection/loop constructs. We should relax this and
2193 // support error cases better.
2194 auto updateMergeInfo = [&](Block *block) -> WalkResult {
2195 auto it = blockMergeInfo.find(Val: block);
2196 if (it != blockMergeInfo.end()) {
2197 // Use the original location for nested selection/loop ops.
2198 Location loc = it->second.loc;
2199
2200 Block *newHeader = mapper.lookupOrNull(from: block);
2201 if (!newHeader)
2202 return emitError(loc, message: "failed control flow structurization: nested "
2203 "loop header block should be remapped!");
2204
2205 Block *newContinue = it->second.continueBlock;
2206 if (newContinue) {
2207 newContinue = mapper.lookupOrNull(from: newContinue);
2208 if (!newContinue)
2209 return emitError(loc, message: "failed control flow structurization: nested "
2210 "loop continue block should be remapped!");
2211 }
2212
2213 Block *newMerge = it->second.mergeBlock;
2214 if (Block *mappedTo = mapper.lookupOrNull(from: newMerge))
2215 newMerge = mappedTo;
2216
2217 // The iterator should be erased before adding a new entry into
2218 // blockMergeInfo to avoid iterator invalidation.
2219 blockMergeInfo.erase(I: it);
2220 blockMergeInfo.try_emplace(Key: newHeader, Args&: loc, Args&: it->second.control, Args&: newMerge,
2221 Args&: newContinue);
2222 }
2223
2224 return WalkResult::advance();
2225 };
2226
2227 if (block->walk(callback&: updateMergeInfo).wasInterrupted())
2228 return failure();
2229
2230 // The structured selection/loop's entry block does not have arguments.
2231 // If the function's header block is also part of the structured control
2232 // flow, we cannot just simply erase it because it may contain arguments
2233 // matching the function signature and used by the cloned blocks.
2234 if (isFnEntryBlock(block)) {
2235 LLVM_DEBUG(logger.startLine() << "[cf] changing entry block " << block
2236 << " to only contain a spirv.Branch op\n");
2237 // Still keep the function entry block for the potential block arguments,
2238 // but replace all ops inside with a branch to the merge block.
2239 block->clear();
2240 builder.setInsertionPointToEnd(block);
2241 builder.create<spirv::BranchOp>(location, mergeBlock);
2242 } else {
2243 LLVM_DEBUG(logger.startLine() << "[cf] erasing block " << block << "\n");
2244 block->erase();
2245 }
2246 }
2247
2248 LLVM_DEBUG(logger.startLine()
2249 << "[cf] after structurizing construct with header block "
2250 << headerBlock << ":\n"
2251 << *op << "\n");
2252
2253 return success();
2254}
2255
2256LogicalResult spirv::Deserializer::wireUpBlockArgument() {
2257 LLVM_DEBUG({
2258 logger.startLine()
2259 << "//----- [phi] start wiring up block arguments -----//\n";
2260 logger.indent();
2261 });
2262
2263 OpBuilder::InsertionGuard guard(opBuilder);
2264
2265 for (const auto &info : blockPhiInfo) {
2266 Block *block = info.first.first;
2267 Block *target = info.first.second;
2268 const BlockPhiInfo &phiInfo = info.second;
2269 LLVM_DEBUG({
2270 logger.startLine() << "[phi] block " << block << "\n";
2271 logger.startLine() << "[phi] before creating block argument:\n";
2272 block->getParentOp()->print(logger.getOStream());
2273 logger.startLine() << "\n";
2274 });
2275
2276 // Set insertion point to before this block's terminator early because we
2277 // may materialize ops via getValue() call.
2278 auto *op = block->getTerminator();
2279 opBuilder.setInsertionPoint(op);
2280
2281 SmallVector<Value, 4> blockArgs;
2282 blockArgs.reserve(N: phiInfo.size());
2283 for (uint32_t valueId : phiInfo) {
2284 if (Value value = getValue(id: valueId)) {
2285 blockArgs.push_back(Elt: value);
2286 LLVM_DEBUG(logger.startLine() << "[phi] block argument " << value
2287 << " id = " << valueId << "\n");
2288 } else {
2289 return emitError(loc: unknownLoc, message: "OpPhi references undefined value!");
2290 }
2291 }
2292
2293 if (auto branchOp = dyn_cast<spirv::BranchOp>(op)) {
2294 // Replace the previous branch op with a new one with block arguments.
2295 opBuilder.create<spirv::BranchOp>(branchOp.getLoc(), branchOp.getTarget(),
2296 blockArgs);
2297 branchOp.erase();
2298 } else if (auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) {
2299 assert((branchCondOp.getTrueBlock() == target ||
2300 branchCondOp.getFalseBlock() == target) &&
2301 "expected target to be either the true or false target");
2302 if (target == branchCondOp.getTrueTarget())
2303 opBuilder.create<spirv::BranchConditionalOp>(
2304 branchCondOp.getLoc(), branchCondOp.getCondition(), blockArgs,
2305 branchCondOp.getFalseBlockArguments(),
2306 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueTarget(),
2307 branchCondOp.getFalseTarget());
2308 else
2309 opBuilder.create<spirv::BranchConditionalOp>(
2310 branchCondOp.getLoc(), branchCondOp.getCondition(),
2311 branchCondOp.getTrueBlockArguments(), blockArgs,
2312 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueBlock(),
2313 branchCondOp.getFalseBlock());
2314
2315 branchCondOp.erase();
2316 } else {
2317 return emitError(loc: unknownLoc, message: "unimplemented terminator for Phi creation");
2318 }
2319
2320 LLVM_DEBUG({
2321 logger.startLine() << "[phi] after creating block argument:\n";
2322 block->getParentOp()->print(logger.getOStream());
2323 logger.startLine() << "\n";
2324 });
2325 }
2326 blockPhiInfo.clear();
2327
2328 LLVM_DEBUG({
2329 logger.unindent();
2330 logger.startLine()
2331 << "//--- [phi] completed wiring up block arguments ---//\n";
2332 });
2333 return success();
2334}
2335
2336LogicalResult spirv::Deserializer::splitConditionalBlocks() {
2337 // Create a copy, so we can modify keys in the original.
2338 BlockMergeInfoMap blockMergeInfoCopy = blockMergeInfo;
2339 for (auto it = blockMergeInfoCopy.begin(), e = blockMergeInfoCopy.end();
2340 it != e; ++it) {
2341 auto &[block, mergeInfo] = *it;
2342
2343 // Skip processing loop regions. For loop regions continueBlock is non-null.
2344 if (mergeInfo.continueBlock)
2345 continue;
2346
2347 if (!block->mightHaveTerminator())
2348 continue;
2349
2350 Operation *terminator = block->getTerminator();
2351 assert(terminator);
2352
2353 if (!isa<spirv::BranchConditionalOp>(terminator))
2354 continue;
2355
2356 // Check if the current header block is a merge block of another construct.
2357 bool splitHeaderMergeBlock = false;
2358 for (const auto &[_, mergeInfo] : blockMergeInfo) {
2359 if (mergeInfo.mergeBlock == block)
2360 splitHeaderMergeBlock = true;
2361 }
2362
2363 // Do not split a block that only contains a conditional branch, unless it
2364 // is also a merge block of another construct - in that case we want to
2365 // split the block. We do not want two constructs to share header / merge
2366 // block.
2367 if (!llvm::hasSingleElement(C&: *block) || splitHeaderMergeBlock) {
2368 Block *newBlock = block->splitBlock(splitBeforeOp: terminator);
2369 OpBuilder builder(block, block->end());
2370 builder.create<spirv::BranchOp>(block->getParent()->getLoc(), newBlock);
2371
2372 // After splitting we need to update the map to use the new block as a
2373 // header.
2374 blockMergeInfo.erase(Val: block);
2375 blockMergeInfo.try_emplace(Key: newBlock, Args&: mergeInfo);
2376 }
2377 }
2378
2379 return success();
2380}
2381
2382LogicalResult spirv::Deserializer::structurizeControlFlow() {
2383 if (!options.enableControlFlowStructurization) {
2384 LLVM_DEBUG(
2385 {
2386 logger.startLine()
2387 << "//----- [cf] skip structurizing control flow -----//\n";
2388 logger.indent();
2389 });
2390 return success();
2391 }
2392
2393 LLVM_DEBUG({
2394 logger.startLine()
2395 << "//----- [cf] start structurizing control flow -----//\n";
2396 logger.indent();
2397 });
2398
2399 LLVM_DEBUG({
2400 logger.startLine() << "[cf] split conditional blocks\n";
2401 logger.startLine() << "\n";
2402 });
2403
2404 if (failed(Result: splitConditionalBlocks())) {
2405 return failure();
2406 }
2407
2408 // TODO: This loop is non-deterministic. Iteration order may vary between runs
2409 // for the same shader as the key to the map is a pointer. See:
2410 // https://github.com/llvm/llvm-project/issues/128547
2411 while (!blockMergeInfo.empty()) {
2412 Block *headerBlock = blockMergeInfo.begin()->first;
2413 BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second;
2414
2415 LLVM_DEBUG({
2416 logger.startLine() << "[cf] header block " << headerBlock << ":\n";
2417 headerBlock->print(logger.getOStream());
2418 logger.startLine() << "\n";
2419 });
2420
2421 auto *mergeBlock = mergeInfo.mergeBlock;
2422 assert(mergeBlock && "merge block cannot be nullptr");
2423 if (mergeInfo.continueBlock && !mergeBlock->args_empty())
2424 return emitError(loc: unknownLoc, message: "OpPhi in loop merge block unimplemented");
2425 LLVM_DEBUG({
2426 logger.startLine() << "[cf] merge block " << mergeBlock << ":\n";
2427 mergeBlock->print(logger.getOStream());
2428 logger.startLine() << "\n";
2429 });
2430
2431 auto *continueBlock = mergeInfo.continueBlock;
2432 LLVM_DEBUG(if (continueBlock) {
2433 logger.startLine() << "[cf] continue block " << continueBlock << ":\n";
2434 continueBlock->print(logger.getOStream());
2435 logger.startLine() << "\n";
2436 });
2437 // Erase this case before calling into structurizer, who will update
2438 // blockMergeInfo.
2439 blockMergeInfo.erase(I: blockMergeInfo.begin());
2440 ControlFlowStructurizer structurizer(mergeInfo.loc, mergeInfo.control,
2441 blockMergeInfo, headerBlock,
2442 mergeBlock, continueBlock
2443#ifndef NDEBUG
2444 ,
2445 logger
2446#endif
2447 );
2448 if (failed(Result: structurizer.structurize()))
2449 return failure();
2450 }
2451
2452 LLVM_DEBUG({
2453 logger.unindent();
2454 logger.startLine()
2455 << "//--- [cf] completed structurizing control flow ---//\n";
2456 });
2457 return success();
2458}
2459
2460//===----------------------------------------------------------------------===//
2461// Debug
2462//===----------------------------------------------------------------------===//
2463
2464Location spirv::Deserializer::createFileLineColLoc(OpBuilder opBuilder) {
2465 if (!debugLine)
2466 return unknownLoc;
2467
2468 auto fileName = debugInfoMap.lookup(Val: debugLine->fileID).str();
2469 if (fileName.empty())
2470 fileName = "<unknown>";
2471 return FileLineColLoc::get(opBuilder.getStringAttr(fileName), debugLine->line,
2472 debugLine->column);
2473}
2474
2475LogicalResult
2476spirv::Deserializer::processDebugLine(ArrayRef<uint32_t> operands) {
2477 // According to SPIR-V spec:
2478 // "This location information applies to the instructions physically
2479 // following this instruction, up to the first occurrence of any of the
2480 // following: the next end of block, the next OpLine instruction, or the next
2481 // OpNoLine instruction."
2482 if (operands.size() != 3)
2483 return emitError(loc: unknownLoc, message: "OpLine must have 3 operands");
2484 debugLine = DebugLine{.fileID: operands[0], .line: operands[1], .column: operands[2]};
2485 return success();
2486}
2487
2488void spirv::Deserializer::clearDebugLine() { debugLine = std::nullopt; }
2489
2490LogicalResult
2491spirv::Deserializer::processDebugString(ArrayRef<uint32_t> operands) {
2492 if (operands.size() < 2)
2493 return emitError(loc: unknownLoc, message: "OpString needs at least 2 operands");
2494
2495 if (!debugInfoMap.lookup(Val: operands[0]).empty())
2496 return emitError(loc: unknownLoc,
2497 message: "duplicate debug string found for result <id> ")
2498 << operands[0];
2499
2500 unsigned wordIndex = 1;
2501 StringRef debugString = decodeStringLiteral(words: operands, wordIndex);
2502 if (wordIndex != operands.size())
2503 return emitError(loc: unknownLoc,
2504 message: "unexpected trailing words in OpString instruction");
2505
2506 debugInfoMap[operands[0]] = debugString;
2507 return success();
2508}
2509

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp