1 | //===- Deserializer.cpp - MLIR SPIR-V Deserializer ------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines the SPIR-V binary to MLIR SPIR-V module deserializer. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "Deserializer.h" |
14 | |
15 | #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" |
16 | #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" |
17 | #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
18 | #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" |
19 | #include "mlir/IR/Builders.h" |
20 | #include "mlir/IR/IRMapping.h" |
21 | #include "mlir/IR/Location.h" |
22 | #include "mlir/Support/LogicalResult.h" |
23 | #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" |
24 | #include "llvm/ADT/STLExtras.h" |
25 | #include "llvm/ADT/Sequence.h" |
26 | #include "llvm/ADT/SmallVector.h" |
27 | #include "llvm/ADT/StringExtras.h" |
28 | #include "llvm/ADT/bit.h" |
29 | #include "llvm/Support/Debug.h" |
30 | #include "llvm/Support/SaveAndRestore.h" |
31 | #include "llvm/Support/raw_ostream.h" |
32 | #include <optional> |
33 | |
34 | using namespace mlir; |
35 | |
36 | #define DEBUG_TYPE "spirv-deserialization" |
37 | |
38 | //===----------------------------------------------------------------------===// |
39 | // Utility Functions |
40 | //===----------------------------------------------------------------------===// |
41 | |
42 | /// Returns true if the given `block` is a function entry block. |
43 | static inline bool isFnEntryBlock(Block *block) { |
44 | return block->isEntryBlock() && |
45 | isa_and_nonnull<spirv::FuncOp>(block->getParentOp()); |
46 | } |
47 | |
48 | //===----------------------------------------------------------------------===// |
49 | // Deserializer Method Definitions |
50 | //===----------------------------------------------------------------------===// |
51 | |
52 | spirv::Deserializer::Deserializer(ArrayRef<uint32_t> binary, |
53 | MLIRContext *context) |
54 | : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)), |
55 | module(createModuleOp()), opBuilder(module->getRegion()) |
56 | #ifndef NDEBUG |
57 | , |
58 | logger(llvm::dbgs()) |
59 | #endif |
60 | { |
61 | } |
62 | |
63 | LogicalResult spirv::Deserializer::deserialize() { |
64 | LLVM_DEBUG({ |
65 | logger.resetIndent(); |
66 | logger.startLine() |
67 | << "//+++---------- start deserialization ----------+++//\n" ; |
68 | }); |
69 | |
70 | if (failed(result: processHeader())) |
71 | return failure(); |
72 | |
73 | spirv::Opcode opcode = spirv::Opcode::OpNop; |
74 | ArrayRef<uint32_t> operands; |
75 | auto binarySize = binary.size(); |
76 | while (curOffset < binarySize) { |
77 | // Slice the next instruction out and populate `opcode` and `operands`. |
78 | // Internally this also updates `curOffset`. |
79 | if (failed(sliceInstruction(opcode, operands))) |
80 | return failure(); |
81 | |
82 | if (failed(processInstruction(opcode, operands))) |
83 | return failure(); |
84 | } |
85 | |
86 | assert(curOffset == binarySize && |
87 | "deserializer should never index beyond the binary end" ); |
88 | |
89 | for (auto &deferred : deferredInstructions) { |
90 | if (failed(processInstruction(deferred.first, deferred.second, false))) { |
91 | return failure(); |
92 | } |
93 | } |
94 | |
95 | attachVCETriple(); |
96 | |
97 | LLVM_DEBUG(logger.startLine() |
98 | << "//+++-------- completed deserialization --------+++//\n" ); |
99 | return success(); |
100 | } |
101 | |
102 | OwningOpRef<spirv::ModuleOp> spirv::Deserializer::collect() { |
103 | return std::move(module); |
104 | } |
105 | |
106 | //===----------------------------------------------------------------------===// |
107 | // Module structure |
108 | //===----------------------------------------------------------------------===// |
109 | |
110 | OwningOpRef<spirv::ModuleOp> spirv::Deserializer::createModuleOp() { |
111 | OpBuilder builder(context); |
112 | OperationState state(unknownLoc, spirv::ModuleOp::getOperationName()); |
113 | spirv::ModuleOp::build(builder, state); |
114 | return cast<spirv::ModuleOp>(Operation::create(state)); |
115 | } |
116 | |
117 | LogicalResult spirv::Deserializer::() { |
118 | if (binary.size() < spirv::kHeaderWordCount) |
119 | return emitError(loc: unknownLoc, |
120 | message: "SPIR-V binary module must have a 5-word header" ); |
121 | |
122 | if (binary[0] != spirv::kMagicNumber) |
123 | return emitError(loc: unknownLoc, message: "incorrect magic number" ); |
124 | |
125 | // Version number bytes: 0 | major number | minor number | 0 |
126 | uint32_t majorVersion = (binary[1] << 8) >> 24; |
127 | uint32_t minorVersion = (binary[1] << 16) >> 24; |
128 | if (majorVersion == 1) { |
129 | switch (minorVersion) { |
130 | #define MIN_VERSION_CASE(v) \ |
131 | case v: \ |
132 | version = spirv::Version::V_1_##v; \ |
133 | break |
134 | |
135 | MIN_VERSION_CASE(0); |
136 | MIN_VERSION_CASE(1); |
137 | MIN_VERSION_CASE(2); |
138 | MIN_VERSION_CASE(3); |
139 | MIN_VERSION_CASE(4); |
140 | MIN_VERSION_CASE(5); |
141 | #undef MIN_VERSION_CASE |
142 | default: |
143 | return emitError(loc: unknownLoc, message: "unsupported SPIR-V minor version: " ) |
144 | << minorVersion; |
145 | } |
146 | } else { |
147 | return emitError(loc: unknownLoc, message: "unsupported SPIR-V major version: " ) |
148 | << majorVersion; |
149 | } |
150 | |
151 | // TODO: generator number, bound, schema |
152 | curOffset = spirv::kHeaderWordCount; |
153 | return success(); |
154 | } |
155 | |
156 | LogicalResult |
157 | spirv::Deserializer::processCapability(ArrayRef<uint32_t> operands) { |
158 | if (operands.size() != 1) |
159 | return emitError(loc: unknownLoc, message: "OpMemoryModel must have one parameter" ); |
160 | |
161 | auto cap = spirv::symbolizeCapability(operands[0]); |
162 | if (!cap) |
163 | return emitError(loc: unknownLoc, message: "unknown capability: " ) << operands[0]; |
164 | |
165 | capabilities.insert(*cap); |
166 | return success(); |
167 | } |
168 | |
169 | LogicalResult spirv::Deserializer::processExtension(ArrayRef<uint32_t> words) { |
170 | if (words.empty()) { |
171 | return emitError( |
172 | loc: unknownLoc, |
173 | message: "OpExtension must have a literal string for the extension name" ); |
174 | } |
175 | |
176 | unsigned wordIndex = 0; |
177 | StringRef extName = decodeStringLiteral(words, wordIndex); |
178 | if (wordIndex != words.size()) |
179 | return emitError(loc: unknownLoc, |
180 | message: "unexpected trailing words in OpExtension instruction" ); |
181 | auto ext = spirv::symbolizeExtension(extName); |
182 | if (!ext) |
183 | return emitError(loc: unknownLoc, message: "unknown extension: " ) << extName; |
184 | |
185 | extensions.insert(*ext); |
186 | return success(); |
187 | } |
188 | |
189 | LogicalResult |
190 | spirv::Deserializer::processExtInstImport(ArrayRef<uint32_t> words) { |
191 | if (words.size() < 2) { |
192 | return emitError(loc: unknownLoc, |
193 | message: "OpExtInstImport must have a result <id> and a literal " |
194 | "string for the extended instruction set name" ); |
195 | } |
196 | |
197 | unsigned wordIndex = 1; |
198 | extendedInstSets[words[0]] = decodeStringLiteral(words, wordIndex); |
199 | if (wordIndex != words.size()) { |
200 | return emitError(loc: unknownLoc, |
201 | message: "unexpected trailing words in OpExtInstImport" ); |
202 | } |
203 | return success(); |
204 | } |
205 | |
206 | void spirv::Deserializer::attachVCETriple() { |
207 | (*module)->setAttr( |
208 | spirv::ModuleOp::getVCETripleAttrName(), |
209 | spirv::VerCapExtAttr::get(version, capabilities.getArrayRef(), |
210 | extensions.getArrayRef(), context)); |
211 | } |
212 | |
213 | LogicalResult |
214 | spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) { |
215 | if (operands.size() != 2) |
216 | return emitError(loc: unknownLoc, message: "OpMemoryModel must have two operands" ); |
217 | |
218 | (*module)->setAttr( |
219 | module->getAddressingModelAttrName(), |
220 | opBuilder.getAttr<spirv::AddressingModelAttr>( |
221 | static_cast<spirv::AddressingModel>(operands.front()))); |
222 | |
223 | (*module)->setAttr(module->getMemoryModelAttrName(), |
224 | opBuilder.getAttr<spirv::MemoryModelAttr>( |
225 | static_cast<spirv::MemoryModel>(operands.back()))); |
226 | |
227 | return success(); |
228 | } |
229 | |
230 | LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) { |
231 | // TODO: This function should also be auto-generated. For now, since only a |
232 | // few decorations are processed/handled in a meaningful manner, going with a |
233 | // manual implementation. |
234 | if (words.size() < 2) { |
235 | return emitError( |
236 | loc: unknownLoc, message: "OpDecorate must have at least result <id> and Decoration" ); |
237 | } |
238 | auto decorationName = |
239 | stringifyDecoration(static_cast<spirv::Decoration>(words[1])); |
240 | if (decorationName.empty()) { |
241 | return emitError(loc: unknownLoc, message: "invalid Decoration code : " ) << words[1]; |
242 | } |
243 | auto symbol = getSymbolDecoration(decorationName: decorationName); |
244 | switch (static_cast<spirv::Decoration>(words[1])) { |
245 | case spirv::Decoration::FPFastMathMode: |
246 | if (words.size() != 3) { |
247 | return emitError(loc: unknownLoc, message: "OpDecorate with " ) |
248 | << decorationName << " needs a single integer literal" ; |
249 | } |
250 | decorations[words[0]].set( |
251 | symbol, FPFastMathModeAttr::get(opBuilder.getContext(), |
252 | static_cast<FPFastMathMode>(words[2]))); |
253 | break; |
254 | case spirv::Decoration::DescriptorSet: |
255 | case spirv::Decoration::Binding: |
256 | if (words.size() != 3) { |
257 | return emitError(loc: unknownLoc, message: "OpDecorate with " ) |
258 | << decorationName << " needs a single integer literal" ; |
259 | } |
260 | decorations[words[0]].set( |
261 | symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2]))); |
262 | break; |
263 | case spirv::Decoration::BuiltIn: |
264 | if (words.size() != 3) { |
265 | return emitError(loc: unknownLoc, message: "OpDecorate with " ) |
266 | << decorationName << " needs a single integer literal" ; |
267 | } |
268 | decorations[words[0]].set( |
269 | symbol, opBuilder.getStringAttr( |
270 | bytes: stringifyBuiltIn(static_cast<spirv::BuiltIn>(words[2])))); |
271 | break; |
272 | case spirv::Decoration::ArrayStride: |
273 | if (words.size() != 3) { |
274 | return emitError(loc: unknownLoc, message: "OpDecorate with " ) |
275 | << decorationName << " needs a single integer literal" ; |
276 | } |
277 | typeDecorations[words[0]] = words[2]; |
278 | break; |
279 | case spirv::Decoration::LinkageAttributes: { |
280 | if (words.size() < 4) { |
281 | return emitError(loc: unknownLoc, message: "OpDecorate with " ) |
282 | << decorationName |
283 | << " needs at least 1 string and 1 integer literal" ; |
284 | } |
285 | // LinkageAttributes has two parameters ["linkageName", linkageType] |
286 | // e.g., OpDecorate %imported_func LinkageAttributes "outside.func" Import |
287 | // "linkageName" is a stringliteral encoded as uint32_t, |
288 | // hence the size of name is variable length which results in words.size() |
289 | // being variable length, words.size() = 3 + strlen(name)/4 + 1 or |
290 | // 3 + ceildiv(strlen(name), 4). |
291 | unsigned wordIndex = 2; |
292 | auto linkageName = spirv::decodeStringLiteral(words, wordIndex).str(); |
293 | auto linkageTypeAttr = opBuilder.getAttr<::mlir::spirv::LinkageTypeAttr>( |
294 | static_cast<::mlir::spirv::LinkageType>(words[wordIndex++])); |
295 | auto linkageAttr = opBuilder.getAttr<::mlir::spirv::LinkageAttributesAttr>( |
296 | StringAttr::get(context, linkageName), linkageTypeAttr); |
297 | decorations[words[0]].set(symbol, llvm::dyn_cast<Attribute>(linkageAttr)); |
298 | break; |
299 | } |
300 | case spirv::Decoration::Aliased: |
301 | case spirv::Decoration::AliasedPointer: |
302 | case spirv::Decoration::Block: |
303 | case spirv::Decoration::BufferBlock: |
304 | case spirv::Decoration::Flat: |
305 | case spirv::Decoration::NonReadable: |
306 | case spirv::Decoration::NonWritable: |
307 | case spirv::Decoration::NoPerspective: |
308 | case spirv::Decoration::NoSignedWrap: |
309 | case spirv::Decoration::NoUnsignedWrap: |
310 | case spirv::Decoration::RelaxedPrecision: |
311 | case spirv::Decoration::Restrict: |
312 | case spirv::Decoration::RestrictPointer: |
313 | case spirv::Decoration::NoContraction: |
314 | if (words.size() != 2) { |
315 | return emitError(loc: unknownLoc, message: "OpDecoration with " ) |
316 | << decorationName << "needs a single target <id>" ; |
317 | } |
318 | // Block decoration does not affect spirv.struct type, but is still stored |
319 | // for verification. |
320 | // TODO: Update StructType to contain this information since |
321 | // it is needed for many validation rules. |
322 | decorations[words[0]].set(symbol, opBuilder.getUnitAttr()); |
323 | break; |
324 | case spirv::Decoration::Location: |
325 | case spirv::Decoration::SpecId: |
326 | if (words.size() != 3) { |
327 | return emitError(loc: unknownLoc, message: "OpDecoration with " ) |
328 | << decorationName << "needs a single integer literal" ; |
329 | } |
330 | decorations[words[0]].set( |
331 | symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2]))); |
332 | break; |
333 | default: |
334 | return emitError(loc: unknownLoc, message: "unhandled Decoration : '" ) << decorationName; |
335 | } |
336 | return success(); |
337 | } |
338 | |
339 | LogicalResult |
340 | spirv::Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) { |
341 | // The binary layout of OpMemberDecorate is different comparing to OpDecorate |
342 | if (words.size() < 3) { |
343 | return emitError(loc: unknownLoc, |
344 | message: "OpMemberDecorate must have at least 3 operands" ); |
345 | } |
346 | |
347 | auto decoration = static_cast<spirv::Decoration>(words[2]); |
348 | if (decoration == spirv::Decoration::Offset && words.size() != 4) { |
349 | return emitError(loc: unknownLoc, |
350 | message: " missing offset specification in OpMemberDecorate with " |
351 | "Offset decoration" ); |
352 | } |
353 | ArrayRef<uint32_t> decorationOperands; |
354 | if (words.size() > 3) { |
355 | decorationOperands = words.slice(N: 3); |
356 | } |
357 | memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands; |
358 | return success(); |
359 | } |
360 | |
361 | LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) { |
362 | if (words.size() < 3) { |
363 | return emitError(loc: unknownLoc, message: "OpMemberName must have at least 3 operands" ); |
364 | } |
365 | unsigned wordIndex = 2; |
366 | auto name = decodeStringLiteral(words, wordIndex); |
367 | if (wordIndex != words.size()) { |
368 | return emitError(loc: unknownLoc, |
369 | message: "unexpected trailing words in OpMemberName instruction" ); |
370 | } |
371 | memberNameMap[words[0]][words[1]] = name; |
372 | return success(); |
373 | } |
374 | |
375 | LogicalResult spirv::Deserializer::setFunctionArgAttrs( |
376 | uint32_t argID, SmallVectorImpl<Attribute> &argAttrs, size_t argIndex) { |
377 | if (!decorations.contains(Val: argID)) { |
378 | argAttrs[argIndex] = DictionaryAttr::get(context, {}); |
379 | return success(); |
380 | } |
381 | |
382 | spirv::DecorationAttr foundDecorationAttr; |
383 | for (NamedAttribute decAttr : decorations[argID]) { |
384 | for (auto decoration : |
385 | {spirv::Decoration::Aliased, spirv::Decoration::Restrict, |
386 | spirv::Decoration::AliasedPointer, |
387 | spirv::Decoration::RestrictPointer}) { |
388 | |
389 | if (decAttr.getName() != |
390 | getSymbolDecoration(stringifyDecoration(decoration))) |
391 | continue; |
392 | |
393 | if (foundDecorationAttr) |
394 | return emitError(unknownLoc, |
395 | "more than one Aliased/Restrict decorations for " |
396 | "function argument with result <id> " ) |
397 | << argID; |
398 | |
399 | foundDecorationAttr = spirv::DecorationAttr::get(context, decoration); |
400 | break; |
401 | } |
402 | } |
403 | |
404 | if (!foundDecorationAttr) |
405 | return emitError(loc: unknownLoc, message: "unimplemented decoration support for " |
406 | "function argument with result <id> " ) |
407 | << argID; |
408 | |
409 | NamedAttribute attr(StringAttr::get(context, spirv::DecorationAttr::name), |
410 | foundDecorationAttr); |
411 | argAttrs[argIndex] = DictionaryAttr::get(context, attr); |
412 | return success(); |
413 | } |
414 | |
415 | LogicalResult |
416 | spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) { |
417 | if (curFunction) { |
418 | return emitError(loc: unknownLoc, message: "found function inside function" ); |
419 | } |
420 | |
421 | // Get the result type |
422 | if (operands.size() != 4) { |
423 | return emitError(loc: unknownLoc, message: "OpFunction must have 4 parameters" ); |
424 | } |
425 | Type resultType = getType(id: operands[0]); |
426 | if (!resultType) { |
427 | return emitError(loc: unknownLoc, message: "undefined result type from <id> " ) |
428 | << operands[0]; |
429 | } |
430 | |
431 | uint32_t fnID = operands[1]; |
432 | if (funcMap.count(fnID)) { |
433 | return emitError(loc: unknownLoc, message: "duplicate function definition/declaration" ); |
434 | } |
435 | |
436 | auto fnControl = spirv::symbolizeFunctionControl(operands[2]); |
437 | if (!fnControl) { |
438 | return emitError(loc: unknownLoc, message: "unknown Function Control: " ) << operands[2]; |
439 | } |
440 | |
441 | Type fnType = getType(id: operands[3]); |
442 | if (!fnType || !isa<FunctionType>(Val: fnType)) { |
443 | return emitError(loc: unknownLoc, message: "unknown function type from <id> " ) |
444 | << operands[3]; |
445 | } |
446 | auto functionType = cast<FunctionType>(fnType); |
447 | |
448 | if ((isVoidType(type: resultType) && functionType.getNumResults() != 0) || |
449 | (functionType.getNumResults() == 1 && |
450 | functionType.getResult(0) != resultType)) { |
451 | return emitError(loc: unknownLoc, message: "mismatch in function type " ) |
452 | << functionType << " and return type " << resultType << " specified" ; |
453 | } |
454 | |
455 | std::string fnName = getFunctionSymbol(id: fnID); |
456 | auto funcOp = opBuilder.create<spirv::FuncOp>( |
457 | unknownLoc, fnName, functionType, fnControl.value()); |
458 | // Processing other function attributes. |
459 | if (decorations.count(Val: fnID)) { |
460 | for (auto attr : decorations[fnID].getAttrs()) { |
461 | funcOp->setAttr(attr.getName(), attr.getValue()); |
462 | } |
463 | } |
464 | curFunction = funcMap[fnID] = funcOp; |
465 | auto *entryBlock = funcOp.addEntryBlock(); |
466 | LLVM_DEBUG({ |
467 | logger.startLine() |
468 | << "//===-------------------------------------------===//\n" ; |
469 | logger.startLine() << "[fn] name: " << fnName << "\n" ; |
470 | logger.startLine() << "[fn] type: " << fnType << "\n" ; |
471 | logger.startLine() << "[fn] ID: " << fnID << "\n" ; |
472 | logger.startLine() << "[fn] entry block: " << entryBlock << "\n" ; |
473 | logger.indent(); |
474 | }); |
475 | |
476 | SmallVector<Attribute> argAttrs; |
477 | argAttrs.resize(functionType.getNumInputs()); |
478 | |
479 | // Parse the op argument instructions |
480 | if (functionType.getNumInputs()) { |
481 | for (size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) { |
482 | auto argType = functionType.getInput(i); |
483 | spirv::Opcode opcode = spirv::Opcode::OpNop; |
484 | ArrayRef<uint32_t> operands; |
485 | if (failed(sliceInstruction(opcode, operands, |
486 | spirv::Opcode::OpFunctionParameter))) { |
487 | return failure(); |
488 | } |
489 | if (opcode != spirv::Opcode::OpFunctionParameter) { |
490 | return emitError( |
491 | loc: unknownLoc, |
492 | message: "missing OpFunctionParameter instruction for argument " ) |
493 | << i; |
494 | } |
495 | if (operands.size() != 2) { |
496 | return emitError( |
497 | loc: unknownLoc, |
498 | message: "expected result type and result <id> for OpFunctionParameter" ); |
499 | } |
500 | auto argDefinedType = getType(id: operands[0]); |
501 | if (!argDefinedType || argDefinedType != argType) { |
502 | return emitError(loc: unknownLoc, |
503 | message: "mismatch in argument type between function type " |
504 | "definition " ) |
505 | << functionType << " and argument type definition " |
506 | << argDefinedType << " at argument " << i; |
507 | } |
508 | if (getValue(id: operands[1])) { |
509 | return emitError(loc: unknownLoc, message: "duplicate definition of result <id> " ) |
510 | << operands[1]; |
511 | } |
512 | if (failed(result: setFunctionArgAttrs(argID: operands[1], argAttrs, argIndex: i))) { |
513 | return failure(); |
514 | } |
515 | |
516 | auto argValue = funcOp.getArgument(i); |
517 | valueMap[operands[1]] = argValue; |
518 | } |
519 | } |
520 | |
521 | if (llvm::any_of(argAttrs, [](Attribute attr) { |
522 | auto argAttr = cast<DictionaryAttr>(attr); |
523 | return !argAttr.empty(); |
524 | })) |
525 | funcOp.setArgAttrsAttr(ArrayAttr::get(context, argAttrs)); |
526 | |
527 | // entryBlock is needed to access the arguments, Once that is done, we can |
528 | // erase the block for functions with 'Import' LinkageAttributes, since these |
529 | // are essentially function declarations, so they have no body. |
530 | auto linkageAttr = funcOp.getLinkageAttributes(); |
531 | auto hasImportLinkage = |
532 | linkageAttr && (linkageAttr.value().getLinkageType().getValue() == |
533 | spirv::LinkageType::Import); |
534 | if (hasImportLinkage) |
535 | funcOp.eraseBody(); |
536 | |
537 | // RAII guard to reset the insertion point to the module's region after |
538 | // deserializing the body of this function. |
539 | OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder); |
540 | |
541 | spirv::Opcode opcode = spirv::Opcode::OpNop; |
542 | ArrayRef<uint32_t> instOperands; |
543 | |
544 | // Special handling for the entry block. We need to make sure it starts with |
545 | // an OpLabel instruction. The entry block takes the same parameters as the |
546 | // function. All other blocks do not take any parameter. We have already |
547 | // created the entry block, here we need to register it to the correct label |
548 | // <id>. |
549 | if (failed(sliceInstruction(opcode, instOperands, |
550 | spirv::Opcode::OpFunctionEnd))) { |
551 | return failure(); |
552 | } |
553 | if (opcode == spirv::Opcode::OpFunctionEnd) { |
554 | return processFunctionEnd(operands: instOperands); |
555 | } |
556 | if (opcode != spirv::Opcode::OpLabel) { |
557 | return emitError(loc: unknownLoc, message: "a basic block must start with OpLabel" ); |
558 | } |
559 | if (instOperands.size() != 1) { |
560 | return emitError(loc: unknownLoc, message: "OpLabel should only have result <id>" ); |
561 | } |
562 | blockMap[instOperands[0]] = entryBlock; |
563 | if (failed(result: processLabel(operands: instOperands))) { |
564 | return failure(); |
565 | } |
566 | |
567 | // Then process all the other instructions in the function until we hit |
568 | // OpFunctionEnd. |
569 | while (succeeded(sliceInstruction(opcode, instOperands, |
570 | spirv::Opcode::OpFunctionEnd)) && |
571 | opcode != spirv::Opcode::OpFunctionEnd) { |
572 | if (failed(processInstruction(opcode, instOperands))) { |
573 | return failure(); |
574 | } |
575 | } |
576 | if (opcode != spirv::Opcode::OpFunctionEnd) { |
577 | return failure(); |
578 | } |
579 | |
580 | return processFunctionEnd(operands: instOperands); |
581 | } |
582 | |
583 | LogicalResult |
584 | spirv::Deserializer::processFunctionEnd(ArrayRef<uint32_t> operands) { |
585 | // Process OpFunctionEnd. |
586 | if (!operands.empty()) { |
587 | return emitError(loc: unknownLoc, message: "unexpected operands for OpFunctionEnd" ); |
588 | } |
589 | |
590 | // Wire up block arguments from OpPhi instructions. |
591 | // Put all structured control flow in spirv.mlir.selection/spirv.mlir.loop |
592 | // ops. |
593 | if (failed(result: wireUpBlockArgument()) || failed(result: structurizeControlFlow())) { |
594 | return failure(); |
595 | } |
596 | |
597 | curBlock = nullptr; |
598 | curFunction = std::nullopt; |
599 | |
600 | LLVM_DEBUG({ |
601 | logger.unindent(); |
602 | logger.startLine() |
603 | << "//===-------------------------------------------===//\n" ; |
604 | }); |
605 | return success(); |
606 | } |
607 | |
608 | std::optional<std::pair<Attribute, Type>> |
609 | spirv::Deserializer::getConstant(uint32_t id) { |
610 | auto constIt = constantMap.find(Val: id); |
611 | if (constIt == constantMap.end()) |
612 | return std::nullopt; |
613 | return constIt->getSecond(); |
614 | } |
615 | |
616 | std::optional<spirv::SpecConstOperationMaterializationInfo> |
617 | spirv::Deserializer::getSpecConstantOperation(uint32_t id) { |
618 | auto constIt = specConstOperationMap.find(Val: id); |
619 | if (constIt == specConstOperationMap.end()) |
620 | return std::nullopt; |
621 | return constIt->getSecond(); |
622 | } |
623 | |
624 | std::string spirv::Deserializer::getFunctionSymbol(uint32_t id) { |
625 | auto funcName = nameMap.lookup(Val: id).str(); |
626 | if (funcName.empty()) { |
627 | funcName = "spirv_fn_" + std::to_string(val: id); |
628 | } |
629 | return funcName; |
630 | } |
631 | |
632 | std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t id) { |
633 | auto constName = nameMap.lookup(Val: id).str(); |
634 | if (constName.empty()) { |
635 | constName = "spirv_spec_const_" + std::to_string(val: id); |
636 | } |
637 | return constName; |
638 | } |
639 | |
640 | spirv::SpecConstantOp |
641 | spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID, |
642 | TypedAttr defaultValue) { |
643 | auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(id: resultID)); |
644 | auto op = opBuilder.create<spirv::SpecConstantOp>(unknownLoc, symName, |
645 | defaultValue); |
646 | if (decorations.count(Val: resultID)) { |
647 | for (auto attr : decorations[resultID].getAttrs()) |
648 | op->setAttr(attr.getName(), attr.getValue()); |
649 | } |
650 | specConstMap[resultID] = op; |
651 | return op; |
652 | } |
653 | |
654 | LogicalResult |
655 | spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) { |
656 | unsigned wordIndex = 0; |
657 | if (operands.size() < 3) { |
658 | return emitError( |
659 | loc: unknownLoc, |
660 | message: "OpVariable needs at least 3 operands, type, <id> and storage class" ); |
661 | } |
662 | |
663 | // Result Type. |
664 | auto type = getType(id: operands[wordIndex]); |
665 | if (!type) { |
666 | return emitError(loc: unknownLoc, message: "unknown result type <id> : " ) |
667 | << operands[wordIndex]; |
668 | } |
669 | auto ptrType = dyn_cast<spirv::PointerType>(Val&: type); |
670 | if (!ptrType) { |
671 | return emitError(loc: unknownLoc, |
672 | message: "expected a result type <id> to be a spirv.ptr, found : " ) |
673 | << type; |
674 | } |
675 | wordIndex++; |
676 | |
677 | // Result <id>. |
678 | auto variableID = operands[wordIndex]; |
679 | auto variableName = nameMap.lookup(Val: variableID).str(); |
680 | if (variableName.empty()) { |
681 | variableName = "spirv_var_" + std::to_string(val: variableID); |
682 | } |
683 | wordIndex++; |
684 | |
685 | // Storage class. |
686 | auto storageClass = static_cast<spirv::StorageClass>(operands[wordIndex]); |
687 | if (ptrType.getStorageClass() != storageClass) { |
688 | return emitError(loc: unknownLoc, message: "mismatch in storage class of pointer type " ) |
689 | << type << " and that specified in OpVariable instruction : " |
690 | << stringifyStorageClass(storageClass); |
691 | } |
692 | wordIndex++; |
693 | |
694 | // Initializer. |
695 | FlatSymbolRefAttr initializer = nullptr; |
696 | |
697 | if (wordIndex < operands.size()) { |
698 | Operation *op = nullptr; |
699 | |
700 | if (auto initOp = getGlobalVariable(operands[wordIndex])) |
701 | op = initOp; |
702 | else if (auto initOp = getSpecConstant(operands[wordIndex])) |
703 | op = initOp; |
704 | else if (auto initOp = getSpecConstantComposite(operands[wordIndex])) |
705 | op = initOp; |
706 | else |
707 | return emitError(loc: unknownLoc, message: "unknown <id> " ) |
708 | << operands[wordIndex] << "used as initializer" ; |
709 | |
710 | initializer = SymbolRefAttr::get(op); |
711 | wordIndex++; |
712 | } |
713 | if (wordIndex != operands.size()) { |
714 | return emitError(loc: unknownLoc, |
715 | message: "found more operands than expected when deserializing " |
716 | "OpVariable instruction, only " ) |
717 | << wordIndex << " of " << operands.size() << " processed" ; |
718 | } |
719 | auto loc = createFileLineColLoc(opBuilder); |
720 | auto varOp = opBuilder.create<spirv::GlobalVariableOp>( |
721 | loc, TypeAttr::get(type), opBuilder.getStringAttr(variableName), |
722 | initializer); |
723 | |
724 | // Decorations. |
725 | if (decorations.count(Val: variableID)) { |
726 | for (auto attr : decorations[variableID].getAttrs()) |
727 | varOp->setAttr(attr.getName(), attr.getValue()); |
728 | } |
729 | globalVariableMap[variableID] = varOp; |
730 | return success(); |
731 | } |
732 | |
733 | IntegerAttr spirv::Deserializer::getConstantInt(uint32_t id) { |
734 | auto constInfo = getConstant(id); |
735 | if (!constInfo) { |
736 | return nullptr; |
737 | } |
738 | return dyn_cast<IntegerAttr>(constInfo->first); |
739 | } |
740 | |
741 | LogicalResult spirv::Deserializer::processName(ArrayRef<uint32_t> operands) { |
742 | if (operands.size() < 2) { |
743 | return emitError(loc: unknownLoc, message: "OpName needs at least 2 operands" ); |
744 | } |
745 | if (!nameMap.lookup(Val: operands[0]).empty()) { |
746 | return emitError(loc: unknownLoc, message: "duplicate name found for result <id> " ) |
747 | << operands[0]; |
748 | } |
749 | unsigned wordIndex = 1; |
750 | StringRef name = decodeStringLiteral(words: operands, wordIndex); |
751 | if (wordIndex != operands.size()) { |
752 | return emitError(loc: unknownLoc, |
753 | message: "unexpected trailing words in OpName instruction" ); |
754 | } |
755 | nameMap[operands[0]] = name; |
756 | return success(); |
757 | } |
758 | |
759 | //===----------------------------------------------------------------------===// |
760 | // Type |
761 | //===----------------------------------------------------------------------===// |
762 | |
763 | LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode, |
764 | ArrayRef<uint32_t> operands) { |
765 | if (operands.empty()) { |
766 | return emitError(unknownLoc, "type instruction with opcode " ) |
767 | << spirv::stringifyOpcode(opcode) << " needs at least one <id>" ; |
768 | } |
769 | |
770 | /// TODO: Types might be forward declared in some instructions and need to be |
771 | /// handled appropriately. |
772 | if (typeMap.count(Val: operands[0])) { |
773 | return emitError(loc: unknownLoc, message: "duplicate definition for result <id> " ) |
774 | << operands[0]; |
775 | } |
776 | |
777 | switch (opcode) { |
778 | case spirv::Opcode::OpTypeVoid: |
779 | if (operands.size() != 1) |
780 | return emitError(loc: unknownLoc, message: "OpTypeVoid must have no parameters" ); |
781 | typeMap[operands[0]] = opBuilder.getNoneType(); |
782 | break; |
783 | case spirv::Opcode::OpTypeBool: |
784 | if (operands.size() != 1) |
785 | return emitError(loc: unknownLoc, message: "OpTypeBool must have no parameters" ); |
786 | typeMap[operands[0]] = opBuilder.getI1Type(); |
787 | break; |
788 | case spirv::Opcode::OpTypeInt: { |
789 | if (operands.size() != 3) |
790 | return emitError( |
791 | loc: unknownLoc, message: "OpTypeInt must have bitwidth and signedness parameters" ); |
792 | |
793 | // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics |
794 | // to preserve or validate. |
795 | // 0 indicates unsigned, or no signedness semantics |
796 | // 1 indicates signed semantics." |
797 | // |
798 | // So we cannot differentiate signless and unsigned integers; always use |
799 | // signless semantics for such cases. |
800 | auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed |
801 | : IntegerType::SignednessSemantics::Signless; |
802 | typeMap[operands[0]] = IntegerType::get(context, operands[1], sign); |
803 | } break; |
804 | case spirv::Opcode::OpTypeFloat: { |
805 | if (operands.size() != 2) |
806 | return emitError(loc: unknownLoc, message: "OpTypeFloat must have bitwidth parameter" ); |
807 | |
808 | Type floatTy; |
809 | switch (operands[1]) { |
810 | case 16: |
811 | floatTy = opBuilder.getF16Type(); |
812 | break; |
813 | case 32: |
814 | floatTy = opBuilder.getF32Type(); |
815 | break; |
816 | case 64: |
817 | floatTy = opBuilder.getF64Type(); |
818 | break; |
819 | default: |
820 | return emitError(loc: unknownLoc, message: "unsupported OpTypeFloat bitwidth: " ) |
821 | << operands[1]; |
822 | } |
823 | typeMap[operands[0]] = floatTy; |
824 | } break; |
825 | case spirv::Opcode::OpTypeVector: { |
826 | if (operands.size() != 3) { |
827 | return emitError( |
828 | loc: unknownLoc, |
829 | message: "OpTypeVector must have element type and count parameters" ); |
830 | } |
831 | Type elementTy = getType(id: operands[1]); |
832 | if (!elementTy) { |
833 | return emitError(loc: unknownLoc, message: "OpTypeVector references undefined <id> " ) |
834 | << operands[1]; |
835 | } |
836 | typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy); |
837 | } break; |
838 | case spirv::Opcode::OpTypePointer: { |
839 | return processOpTypePointer(operands); |
840 | } break; |
841 | case spirv::Opcode::OpTypeArray: |
842 | return processArrayType(operands); |
843 | case spirv::Opcode::OpTypeCooperativeMatrixKHR: |
844 | return processCooperativeMatrixTypeKHR(operands); |
845 | case spirv::Opcode::OpTypeFunction: |
846 | return processFunctionType(operands); |
847 | case spirv::Opcode::OpTypeJointMatrixINTEL: |
848 | return processJointMatrixType(operands); |
849 | case spirv::Opcode::OpTypeImage: |
850 | return processImageType(operands); |
851 | case spirv::Opcode::OpTypeSampledImage: |
852 | return processSampledImageType(operands); |
853 | case spirv::Opcode::OpTypeRuntimeArray: |
854 | return processRuntimeArrayType(operands); |
855 | case spirv::Opcode::OpTypeStruct: |
856 | return processStructType(operands); |
857 | case spirv::Opcode::OpTypeMatrix: |
858 | return processMatrixType(operands); |
859 | default: |
860 | return emitError(loc: unknownLoc, message: "unhandled type instruction" ); |
861 | } |
862 | return success(); |
863 | } |
864 | |
865 | LogicalResult |
866 | spirv::Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) { |
867 | if (operands.size() != 3) |
868 | return emitError(loc: unknownLoc, message: "OpTypePointer must have two parameters" ); |
869 | |
870 | auto pointeeType = getType(id: operands[2]); |
871 | if (!pointeeType) |
872 | return emitError(loc: unknownLoc, message: "unknown OpTypePointer pointee type <id> " ) |
873 | << operands[2]; |
874 | |
875 | uint32_t typePointerID = operands[0]; |
876 | auto storageClass = static_cast<spirv::StorageClass>(operands[1]); |
877 | typeMap[typePointerID] = spirv::PointerType::get(pointeeType, storageClass); |
878 | |
879 | for (auto *deferredStructIt = std::begin(cont&: deferredStructTypesInfos); |
880 | deferredStructIt != std::end(cont&: deferredStructTypesInfos);) { |
881 | for (auto *unresolvedMemberIt = |
882 | std::begin(cont&: deferredStructIt->unresolvedMemberTypes); |
883 | unresolvedMemberIt != |
884 | std::end(cont&: deferredStructIt->unresolvedMemberTypes);) { |
885 | if (unresolvedMemberIt->first == typePointerID) { |
886 | // The newly constructed pointer type can resolve one of the |
887 | // deferred struct type members; update the memberTypes list and |
888 | // clean the unresolvedMemberTypes list accordingly. |
889 | deferredStructIt->memberTypes[unresolvedMemberIt->second] = |
890 | typeMap[typePointerID]; |
891 | unresolvedMemberIt = |
892 | deferredStructIt->unresolvedMemberTypes.erase(CI: unresolvedMemberIt); |
893 | } else { |
894 | ++unresolvedMemberIt; |
895 | } |
896 | } |
897 | |
898 | if (deferredStructIt->unresolvedMemberTypes.empty()) { |
899 | // All deferred struct type members are now resolved, set the struct body. |
900 | auto structType = deferredStructIt->deferredStructType; |
901 | |
902 | assert(structType && "expected a spirv::StructType" ); |
903 | assert(structType.isIdentified() && "expected an indentified struct" ); |
904 | |
905 | if (failed(result: structType.trySetBody( |
906 | memberTypes: deferredStructIt->memberTypes, offsetInfo: deferredStructIt->offsetInfo, |
907 | memberDecorations: deferredStructIt->memberDecorationsInfo))) |
908 | return failure(); |
909 | |
910 | deferredStructIt = deferredStructTypesInfos.erase(CI: deferredStructIt); |
911 | } else { |
912 | ++deferredStructIt; |
913 | } |
914 | } |
915 | |
916 | return success(); |
917 | } |
918 | |
919 | LogicalResult |
920 | spirv::Deserializer::processArrayType(ArrayRef<uint32_t> operands) { |
921 | if (operands.size() != 3) { |
922 | return emitError(loc: unknownLoc, |
923 | message: "OpTypeArray must have element type and count parameters" ); |
924 | } |
925 | |
926 | Type elementTy = getType(id: operands[1]); |
927 | if (!elementTy) { |
928 | return emitError(loc: unknownLoc, message: "OpTypeArray references undefined <id> " ) |
929 | << operands[1]; |
930 | } |
931 | |
932 | unsigned count = 0; |
933 | // TODO: The count can also come frome a specialization constant. |
934 | auto countInfo = getConstant(id: operands[2]); |
935 | if (!countInfo) { |
936 | return emitError(loc: unknownLoc, message: "OpTypeArray count <id> " ) |
937 | << operands[2] << "can only come from normal constant right now" ; |
938 | } |
939 | |
940 | if (auto intVal = dyn_cast<IntegerAttr>(countInfo->first)) { |
941 | count = intVal.getValue().getZExtValue(); |
942 | } else { |
943 | return emitError(loc: unknownLoc, message: "OpTypeArray count must come from a " |
944 | "scalar integer constant instruction" ); |
945 | } |
946 | |
947 | typeMap[operands[0]] = spirv::ArrayType::get( |
948 | elementType: elementTy, elementCount: count, stride: typeDecorations.lookup(Val: operands[0])); |
949 | return success(); |
950 | } |
951 | |
952 | LogicalResult |
953 | spirv::Deserializer::processFunctionType(ArrayRef<uint32_t> operands) { |
954 | assert(!operands.empty() && "No operands for processing function type" ); |
955 | if (operands.size() == 1) { |
956 | return emitError(loc: unknownLoc, message: "missing return type for OpTypeFunction" ); |
957 | } |
958 | auto returnType = getType(id: operands[1]); |
959 | if (!returnType) { |
960 | return emitError(loc: unknownLoc, message: "unknown return type in OpTypeFunction" ); |
961 | } |
962 | SmallVector<Type, 1> argTypes; |
963 | for (size_t i = 2, e = operands.size(); i < e; ++i) { |
964 | auto ty = getType(id: operands[i]); |
965 | if (!ty) { |
966 | return emitError(loc: unknownLoc, message: "unknown argument type in OpTypeFunction" ); |
967 | } |
968 | argTypes.push_back(Elt: ty); |
969 | } |
970 | ArrayRef<Type> returnTypes; |
971 | if (!isVoidType(type: returnType)) { |
972 | returnTypes = llvm::ArrayRef(returnType); |
973 | } |
974 | typeMap[operands[0]] = FunctionType::get(context, argTypes, returnTypes); |
975 | return success(); |
976 | } |
977 | |
978 | LogicalResult spirv::Deserializer::processCooperativeMatrixTypeKHR( |
979 | ArrayRef<uint32_t> operands) { |
980 | if (operands.size() != 6) { |
981 | return emitError(loc: unknownLoc, |
982 | message: "OpTypeCooperativeMatrixKHR must have element type, " |
983 | "scope, row and column parameters, and use" ); |
984 | } |
985 | |
986 | Type elementTy = getType(id: operands[1]); |
987 | if (!elementTy) { |
988 | return emitError(loc: unknownLoc, |
989 | message: "OpTypeCooperativeMatrixKHR references undefined <id> " ) |
990 | << operands[1]; |
991 | } |
992 | |
993 | std::optional<spirv::Scope> scope = |
994 | spirv::symbolizeScope(getConstantInt(operands[2]).getInt()); |
995 | if (!scope) { |
996 | return emitError( |
997 | loc: unknownLoc, |
998 | message: "OpTypeCooperativeMatrixKHR references undefined scope <id> " ) |
999 | << operands[2]; |
1000 | } |
1001 | |
1002 | unsigned rows = getConstantInt(operands[3]).getInt(); |
1003 | unsigned columns = getConstantInt(operands[4]).getInt(); |
1004 | |
1005 | std::optional<spirv::CooperativeMatrixUseKHR> use = |
1006 | spirv::symbolizeCooperativeMatrixUseKHR( |
1007 | getConstantInt(operands[5]).getInt()); |
1008 | if (!use) { |
1009 | return emitError( |
1010 | loc: unknownLoc, |
1011 | message: "OpTypeCooperativeMatrixKHR references undefined use <id> " ) |
1012 | << operands[5]; |
1013 | } |
1014 | |
1015 | typeMap[operands[0]] = |
1016 | spirv::CooperativeMatrixType::get(elementTy, rows, columns, *scope, *use); |
1017 | return success(); |
1018 | } |
1019 | |
1020 | LogicalResult |
1021 | spirv::Deserializer::processJointMatrixType(ArrayRef<uint32_t> operands) { |
1022 | if (operands.size() != 6) { |
1023 | return emitError(loc: unknownLoc, message: "OpTypeJointMatrix must have element " |
1024 | "type and row x column parameters" ); |
1025 | } |
1026 | |
1027 | Type elementTy = getType(id: operands[1]); |
1028 | if (!elementTy) { |
1029 | return emitError(loc: unknownLoc, message: "OpTypeJointMatrix references undefined <id> " ) |
1030 | << operands[1]; |
1031 | } |
1032 | |
1033 | auto scope = spirv::symbolizeScope(getConstantInt(operands[5]).getInt()); |
1034 | if (!scope) { |
1035 | return emitError(loc: unknownLoc, |
1036 | message: "OpTypeJointMatrix references undefined scope <id> " ) |
1037 | << operands[5]; |
1038 | } |
1039 | auto matrixLayout = |
1040 | spirv::symbolizeMatrixLayout(getConstantInt(operands[4]).getInt()); |
1041 | if (!matrixLayout) { |
1042 | return emitError(loc: unknownLoc, |
1043 | message: "OpTypeJointMatrix references undefined scope <id> " ) |
1044 | << operands[4]; |
1045 | } |
1046 | unsigned rows = getConstantInt(operands[2]).getInt(); |
1047 | unsigned columns = getConstantInt(operands[3]).getInt(); |
1048 | |
1049 | typeMap[operands[0]] = spirv::JointMatrixINTELType::get( |
1050 | elementTy, scope.value(), rows, columns, matrixLayout.value()); |
1051 | return success(); |
1052 | } |
1053 | |
1054 | LogicalResult |
1055 | spirv::Deserializer::processRuntimeArrayType(ArrayRef<uint32_t> operands) { |
1056 | if (operands.size() != 2) { |
1057 | return emitError(loc: unknownLoc, message: "OpTypeRuntimeArray must have two operands" ); |
1058 | } |
1059 | Type memberType = getType(id: operands[1]); |
1060 | if (!memberType) { |
1061 | return emitError(loc: unknownLoc, |
1062 | message: "OpTypeRuntimeArray references undefined <id> " ) |
1063 | << operands[1]; |
1064 | } |
1065 | typeMap[operands[0]] = spirv::RuntimeArrayType::get( |
1066 | elementType: memberType, stride: typeDecorations.lookup(Val: operands[0])); |
1067 | return success(); |
1068 | } |
1069 | |
1070 | LogicalResult |
1071 | spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) { |
1072 | // TODO: Find a way to handle identified structs when debug info is stripped. |
1073 | |
1074 | if (operands.empty()) { |
1075 | return emitError(loc: unknownLoc, message: "OpTypeStruct must have at least result <id>" ); |
1076 | } |
1077 | |
1078 | if (operands.size() == 1) { |
1079 | // Handle empty struct. |
1080 | typeMap[operands[0]] = |
1081 | spirv::StructType::getEmpty(context, identifier: nameMap.lookup(Val: operands[0]).str()); |
1082 | return success(); |
1083 | } |
1084 | |
1085 | // First element is operand ID, second element is member index in the struct. |
1086 | SmallVector<std::pair<uint32_t, unsigned>, 0> unresolvedMemberTypes; |
1087 | SmallVector<Type, 4> memberTypes; |
1088 | |
1089 | for (auto op : llvm::drop_begin(RangeOrContainer&: operands, N: 1)) { |
1090 | Type memberType = getType(id: op); |
1091 | bool typeForwardPtr = (typeForwardPointerIDs.count(key: op) != 0); |
1092 | |
1093 | if (!memberType && !typeForwardPtr) |
1094 | return emitError(loc: unknownLoc, message: "OpTypeStruct references undefined <id> " ) |
1095 | << op; |
1096 | |
1097 | if (!memberType) |
1098 | unresolvedMemberTypes.emplace_back(Args&: op, Args: memberTypes.size()); |
1099 | |
1100 | memberTypes.push_back(Elt: memberType); |
1101 | } |
1102 | |
1103 | SmallVector<spirv::StructType::OffsetInfo, 0> offsetInfo; |
1104 | SmallVector<spirv::StructType::MemberDecorationInfo, 0> memberDecorationsInfo; |
1105 | if (memberDecorationMap.count(operands[0])) { |
1106 | auto &allMemberDecorations = memberDecorationMap[operands[0]]; |
1107 | for (auto memberIndex : llvm::seq<uint32_t>(Begin: 0, End: memberTypes.size())) { |
1108 | if (allMemberDecorations.count(memberIndex)) { |
1109 | for (auto &memberDecoration : allMemberDecorations[memberIndex]) { |
1110 | // Check for offset. |
1111 | if (memberDecoration.first == spirv::Decoration::Offset) { |
1112 | // If offset info is empty, resize to the number of members; |
1113 | if (offsetInfo.empty()) { |
1114 | offsetInfo.resize(memberTypes.size()); |
1115 | } |
1116 | offsetInfo[memberIndex] = memberDecoration.second[0]; |
1117 | } else { |
1118 | if (!memberDecoration.second.empty()) { |
1119 | memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/1, |
1120 | memberDecoration.first, |
1121 | memberDecoration.second[0]); |
1122 | } else { |
1123 | memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/0, |
1124 | memberDecoration.first, 0); |
1125 | } |
1126 | } |
1127 | } |
1128 | } |
1129 | } |
1130 | } |
1131 | |
1132 | uint32_t structID = operands[0]; |
1133 | std::string structIdentifier = nameMap.lookup(Val: structID).str(); |
1134 | |
1135 | if (structIdentifier.empty()) { |
1136 | assert(unresolvedMemberTypes.empty() && |
1137 | "didn't expect unresolved member types" ); |
1138 | typeMap[structID] = |
1139 | spirv::StructType::get(memberTypes, offsetInfo, memberDecorations: memberDecorationsInfo); |
1140 | } else { |
1141 | auto structTy = spirv::StructType::getIdentified(context, identifier: structIdentifier); |
1142 | typeMap[structID] = structTy; |
1143 | |
1144 | if (!unresolvedMemberTypes.empty()) |
1145 | deferredStructTypesInfos.push_back(Elt: {.deferredStructType: structTy, .unresolvedMemberTypes: unresolvedMemberTypes, |
1146 | .memberTypes: memberTypes, .offsetInfo: offsetInfo, |
1147 | .memberDecorationsInfo: memberDecorationsInfo}); |
1148 | else if (failed(result: structTy.trySetBody(memberTypes, offsetInfo, |
1149 | memberDecorations: memberDecorationsInfo))) |
1150 | return failure(); |
1151 | } |
1152 | |
1153 | // TODO: Update StructType to have member name as attribute as |
1154 | // well. |
1155 | return success(); |
1156 | } |
1157 | |
1158 | LogicalResult |
1159 | spirv::Deserializer::processMatrixType(ArrayRef<uint32_t> operands) { |
1160 | if (operands.size() != 3) { |
1161 | // Three operands are needed: result_id, column_type, and column_count |
1162 | return emitError(loc: unknownLoc, message: "OpTypeMatrix must have 3 operands" |
1163 | " (result_id, column_type, and column_count)" ); |
1164 | } |
1165 | // Matrix columns must be of vector type |
1166 | Type elementTy = getType(id: operands[1]); |
1167 | if (!elementTy) { |
1168 | return emitError(loc: unknownLoc, |
1169 | message: "OpTypeMatrix references undefined column type." ) |
1170 | << operands[1]; |
1171 | } |
1172 | |
1173 | uint32_t colsCount = operands[2]; |
1174 | typeMap[operands[0]] = spirv::MatrixType::get(columnType: elementTy, columnCount: colsCount); |
1175 | return success(); |
1176 | } |
1177 | |
1178 | LogicalResult |
1179 | spirv::Deserializer::processTypeForwardPointer(ArrayRef<uint32_t> operands) { |
1180 | if (operands.size() != 2) |
1181 | return emitError(loc: unknownLoc, |
1182 | message: "OpTypeForwardPointer instruction must have two operands" ); |
1183 | |
1184 | typeForwardPointerIDs.insert(X: operands[0]); |
1185 | // TODO: Use the 2nd operand (Storage Class) to validate the OpTypePointer |
1186 | // instruction that defines the actual type. |
1187 | |
1188 | return success(); |
1189 | } |
1190 | |
1191 | LogicalResult |
1192 | spirv::Deserializer::processImageType(ArrayRef<uint32_t> operands) { |
1193 | // TODO: Add support for Access Qualifier. |
1194 | if (operands.size() != 8) |
1195 | return emitError( |
1196 | loc: unknownLoc, |
1197 | message: "OpTypeImage with non-eight operands are not supported yet" ); |
1198 | |
1199 | Type elementTy = getType(id: operands[1]); |
1200 | if (!elementTy) |
1201 | return emitError(loc: unknownLoc, message: "OpTypeImage references undefined <id>: " ) |
1202 | << operands[1]; |
1203 | |
1204 | auto dim = spirv::symbolizeDim(operands[2]); |
1205 | if (!dim) |
1206 | return emitError(loc: unknownLoc, message: "unknown Dim for OpTypeImage: " ) |
1207 | << operands[2]; |
1208 | |
1209 | auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]); |
1210 | if (!depthInfo) |
1211 | return emitError(loc: unknownLoc, message: "unknown Depth for OpTypeImage: " ) |
1212 | << operands[3]; |
1213 | |
1214 | auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]); |
1215 | if (!arrayedInfo) |
1216 | return emitError(loc: unknownLoc, message: "unknown Arrayed for OpTypeImage: " ) |
1217 | << operands[4]; |
1218 | |
1219 | auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]); |
1220 | if (!samplingInfo) |
1221 | return emitError(loc: unknownLoc, message: "unknown MS for OpTypeImage: " ) << operands[5]; |
1222 | |
1223 | auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]); |
1224 | if (!samplerUseInfo) |
1225 | return emitError(loc: unknownLoc, message: "unknown Sampled for OpTypeImage: " ) |
1226 | << operands[6]; |
1227 | |
1228 | auto format = spirv::symbolizeImageFormat(operands[7]); |
1229 | if (!format) |
1230 | return emitError(loc: unknownLoc, message: "unknown Format for OpTypeImage: " ) |
1231 | << operands[7]; |
1232 | |
1233 | typeMap[operands[0]] = spirv::ImageType::get( |
1234 | elementTy, dim.value(), depthInfo.value(), arrayedInfo.value(), |
1235 | samplingInfo.value(), samplerUseInfo.value(), format.value()); |
1236 | return success(); |
1237 | } |
1238 | |
1239 | LogicalResult |
1240 | spirv::Deserializer::processSampledImageType(ArrayRef<uint32_t> operands) { |
1241 | if (operands.size() != 2) |
1242 | return emitError(loc: unknownLoc, message: "OpTypeSampledImage must have two operands" ); |
1243 | |
1244 | Type elementTy = getType(id: operands[1]); |
1245 | if (!elementTy) |
1246 | return emitError(loc: unknownLoc, |
1247 | message: "OpTypeSampledImage references undefined <id>: " ) |
1248 | << operands[1]; |
1249 | |
1250 | typeMap[operands[0]] = spirv::SampledImageType::get(imageType: elementTy); |
1251 | return success(); |
1252 | } |
1253 | |
1254 | //===----------------------------------------------------------------------===// |
1255 | // Constant |
1256 | //===----------------------------------------------------------------------===// |
1257 | |
1258 | LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands, |
1259 | bool isSpec) { |
1260 | StringRef opname = isSpec ? "OpSpecConstant" : "OpConstant" ; |
1261 | |
1262 | if (operands.size() < 2) { |
1263 | return emitError(loc: unknownLoc) |
1264 | << opname << " must have type <id> and result <id>" ; |
1265 | } |
1266 | if (operands.size() < 3) { |
1267 | return emitError(loc: unknownLoc) |
1268 | << opname << " must have at least 1 more parameter" ; |
1269 | } |
1270 | |
1271 | Type resultType = getType(id: operands[0]); |
1272 | if (!resultType) { |
1273 | return emitError(loc: unknownLoc, message: "undefined result type from <id> " ) |
1274 | << operands[0]; |
1275 | } |
1276 | |
1277 | auto checkOperandSizeForBitwidth = [&](unsigned bitwidth) -> LogicalResult { |
1278 | if (bitwidth == 64) { |
1279 | if (operands.size() == 4) { |
1280 | return success(); |
1281 | } |
1282 | return emitError(loc: unknownLoc) |
1283 | << opname << " should have 2 parameters for 64-bit values" ; |
1284 | } |
1285 | if (bitwidth <= 32) { |
1286 | if (operands.size() == 3) { |
1287 | return success(); |
1288 | } |
1289 | |
1290 | return emitError(loc: unknownLoc) |
1291 | << opname |
1292 | << " should have 1 parameter for values with no more than 32 bits" ; |
1293 | } |
1294 | return emitError(loc: unknownLoc, message: "unsupported OpConstant bitwidth: " ) |
1295 | << bitwidth; |
1296 | }; |
1297 | |
1298 | auto resultID = operands[1]; |
1299 | |
1300 | if (auto intType = dyn_cast<IntegerType>(resultType)) { |
1301 | auto bitwidth = intType.getWidth(); |
1302 | if (failed(checkOperandSizeForBitwidth(bitwidth))) { |
1303 | return failure(); |
1304 | } |
1305 | |
1306 | APInt value; |
1307 | if (bitwidth == 64) { |
1308 | // 64-bit integers are represented with two SPIR-V words. According to |
1309 | // SPIR-V spec: "When the type’s bit width is larger than one word, the |
1310 | // literal’s low-order words appear first." |
1311 | struct DoubleWord { |
1312 | uint32_t word1; |
1313 | uint32_t word2; |
1314 | } words = {.word1: operands[2], .word2: operands[3]}; |
1315 | value = APInt(64, llvm::bit_cast<uint64_t>(from: words), /*isSigned=*/true); |
1316 | } else if (bitwidth <= 32) { |
1317 | value = APInt(bitwidth, operands[2], /*isSigned=*/true); |
1318 | } |
1319 | |
1320 | auto attr = opBuilder.getIntegerAttr(intType, value); |
1321 | |
1322 | if (isSpec) { |
1323 | createSpecConstant(unknownLoc, resultID, attr); |
1324 | } else { |
1325 | // For normal constants, we just record the attribute (and its type) for |
1326 | // later materialization at use sites. |
1327 | constantMap.try_emplace(resultID, attr, intType); |
1328 | } |
1329 | |
1330 | return success(); |
1331 | } |
1332 | |
1333 | if (auto floatType = dyn_cast<FloatType>(Val&: resultType)) { |
1334 | auto bitwidth = floatType.getWidth(); |
1335 | if (failed(result: checkOperandSizeForBitwidth(bitwidth))) { |
1336 | return failure(); |
1337 | } |
1338 | |
1339 | APFloat value(0.f); |
1340 | if (floatType.isF64()) { |
1341 | // Double values are represented with two SPIR-V words. According to |
1342 | // SPIR-V spec: "When the type’s bit width is larger than one word, the |
1343 | // literal’s low-order words appear first." |
1344 | struct DoubleWord { |
1345 | uint32_t word1; |
1346 | uint32_t word2; |
1347 | } words = {.word1: operands[2], .word2: operands[3]}; |
1348 | value = APFloat(llvm::bit_cast<double>(from: words)); |
1349 | } else if (floatType.isF32()) { |
1350 | value = APFloat(llvm::bit_cast<float>(from: operands[2])); |
1351 | } else if (floatType.isF16()) { |
1352 | APInt data(16, operands[2]); |
1353 | value = APFloat(APFloat::IEEEhalf(), data); |
1354 | } |
1355 | |
1356 | auto attr = opBuilder.getFloatAttr(floatType, value); |
1357 | if (isSpec) { |
1358 | createSpecConstant(unknownLoc, resultID, attr); |
1359 | } else { |
1360 | // For normal constants, we just record the attribute (and its type) for |
1361 | // later materialization at use sites. |
1362 | constantMap.try_emplace(resultID, attr, floatType); |
1363 | } |
1364 | |
1365 | return success(); |
1366 | } |
1367 | |
1368 | return emitError(loc: unknownLoc, message: "OpConstant can only generate values of " |
1369 | "scalar integer or floating-point type" ); |
1370 | } |
1371 | |
1372 | LogicalResult spirv::Deserializer::processConstantBool( |
1373 | bool isTrue, ArrayRef<uint32_t> operands, bool isSpec) { |
1374 | if (operands.size() != 2) { |
1375 | return emitError(loc: unknownLoc, message: "Op" ) |
1376 | << (isSpec ? "Spec" : "" ) << "Constant" |
1377 | << (isTrue ? "True" : "False" ) |
1378 | << " must have type <id> and result <id>" ; |
1379 | } |
1380 | |
1381 | auto attr = opBuilder.getBoolAttr(value: isTrue); |
1382 | auto resultID = operands[1]; |
1383 | if (isSpec) { |
1384 | createSpecConstant(unknownLoc, resultID, attr); |
1385 | } else { |
1386 | // For normal constants, we just record the attribute (and its type) for |
1387 | // later materialization at use sites. |
1388 | constantMap.try_emplace(Key: resultID, Args&: attr, Args: opBuilder.getI1Type()); |
1389 | } |
1390 | |
1391 | return success(); |
1392 | } |
1393 | |
1394 | LogicalResult |
1395 | spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) { |
1396 | if (operands.size() < 2) { |
1397 | return emitError(loc: unknownLoc, |
1398 | message: "OpConstantComposite must have type <id> and result <id>" ); |
1399 | } |
1400 | if (operands.size() < 3) { |
1401 | return emitError(loc: unknownLoc, |
1402 | message: "OpConstantComposite must have at least 1 parameter" ); |
1403 | } |
1404 | |
1405 | Type resultType = getType(id: operands[0]); |
1406 | if (!resultType) { |
1407 | return emitError(loc: unknownLoc, message: "undefined result type from <id> " ) |
1408 | << operands[0]; |
1409 | } |
1410 | |
1411 | SmallVector<Attribute, 4> elements; |
1412 | elements.reserve(N: operands.size() - 2); |
1413 | for (unsigned i = 2, e = operands.size(); i < e; ++i) { |
1414 | auto elementInfo = getConstant(id: operands[i]); |
1415 | if (!elementInfo) { |
1416 | return emitError(loc: unknownLoc, message: "OpConstantComposite component <id> " ) |
1417 | << operands[i] << " must come from a normal constant" ; |
1418 | } |
1419 | elements.push_back(Elt: elementInfo->first); |
1420 | } |
1421 | |
1422 | auto resultID = operands[1]; |
1423 | if (auto vectorType = dyn_cast<VectorType>(resultType)) { |
1424 | auto attr = DenseElementsAttr::get(vectorType, elements); |
1425 | // For normal constants, we just record the attribute (and its type) for |
1426 | // later materialization at use sites. |
1427 | constantMap.try_emplace(resultID, attr, resultType); |
1428 | } else if (auto arrayType = dyn_cast<spirv::ArrayType>(Val&: resultType)) { |
1429 | auto attr = opBuilder.getArrayAttr(elements); |
1430 | constantMap.try_emplace(resultID, attr, resultType); |
1431 | } else { |
1432 | return emitError(loc: unknownLoc, message: "unsupported OpConstantComposite type: " ) |
1433 | << resultType; |
1434 | } |
1435 | |
1436 | return success(); |
1437 | } |
1438 | |
1439 | LogicalResult |
1440 | spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) { |
1441 | if (operands.size() < 2) { |
1442 | return emitError(loc: unknownLoc, |
1443 | message: "OpConstantComposite must have type <id> and result <id>" ); |
1444 | } |
1445 | if (operands.size() < 3) { |
1446 | return emitError(loc: unknownLoc, |
1447 | message: "OpConstantComposite must have at least 1 parameter" ); |
1448 | } |
1449 | |
1450 | Type resultType = getType(id: operands[0]); |
1451 | if (!resultType) { |
1452 | return emitError(loc: unknownLoc, message: "undefined result type from <id> " ) |
1453 | << operands[0]; |
1454 | } |
1455 | |
1456 | auto resultID = operands[1]; |
1457 | auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(id: resultID)); |
1458 | |
1459 | SmallVector<Attribute, 4> elements; |
1460 | elements.reserve(N: operands.size() - 2); |
1461 | for (unsigned i = 2, e = operands.size(); i < e; ++i) { |
1462 | auto elementInfo = getSpecConstant(operands[i]); |
1463 | elements.push_back(SymbolRefAttr::get(elementInfo)); |
1464 | } |
1465 | |
1466 | auto op = opBuilder.create<spirv::SpecConstantCompositeOp>( |
1467 | unknownLoc, TypeAttr::get(resultType), symName, |
1468 | opBuilder.getArrayAttr(elements)); |
1469 | specConstCompositeMap[resultID] = op; |
1470 | |
1471 | return success(); |
1472 | } |
1473 | |
1474 | LogicalResult |
1475 | spirv::Deserializer::processSpecConstantOperation(ArrayRef<uint32_t> operands) { |
1476 | if (operands.size() < 3) |
1477 | return emitError(loc: unknownLoc, message: "OpConstantOperation must have type <id>, " |
1478 | "result <id>, and operand opcode" ); |
1479 | |
1480 | uint32_t resultTypeID = operands[0]; |
1481 | |
1482 | if (!getType(id: resultTypeID)) |
1483 | return emitError(loc: unknownLoc, message: "undefined result type from <id> " ) |
1484 | << resultTypeID; |
1485 | |
1486 | uint32_t resultID = operands[1]; |
1487 | spirv::Opcode enclosedOpcode = static_cast<spirv::Opcode>(operands[2]); |
1488 | auto emplaceResult = specConstOperationMap.try_emplace( |
1489 | Key: resultID, |
1490 | Args: SpecConstOperationMaterializationInfo{ |
1491 | enclosedOpcode, resultTypeID, |
1492 | SmallVector<uint32_t>{operands.begin() + 3, operands.end()}}); |
1493 | |
1494 | if (!emplaceResult.second) |
1495 | return emitError(loc: unknownLoc, message: "value with <id>: " ) |
1496 | << resultID << " is probably defined before." ; |
1497 | |
1498 | return success(); |
1499 | } |
1500 | |
1501 | Value spirv::Deserializer::materializeSpecConstantOperation( |
1502 | uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID, |
1503 | ArrayRef<uint32_t> enclosedOpOperands) { |
1504 | |
1505 | Type resultType = getType(id: resultTypeID); |
1506 | |
1507 | // Instructions wrapped by OpSpecConstantOp need an ID for their |
1508 | // Deserializer::processOp<op_name>(...) to emit the corresponding SPIR-V |
1509 | // dialect wrapped op. For that purpose, a new value map is created and "fake" |
1510 | // ID in that map is assigned to the result of the enclosed instruction. Note |
1511 | // that there is no need to update this fake ID since we only need to |
1512 | // reference the created Value for the enclosed op from the spv::YieldOp |
1513 | // created later in this method (both of which are the only values in their |
1514 | // region: the SpecConstantOperation's region). If we encounter another |
1515 | // SpecConstantOperation in the module, we simply re-use the fake ID since the |
1516 | // previous Value assigned to it isn't visible in the current scope anyway. |
1517 | DenseMap<uint32_t, Value> newValueMap; |
1518 | llvm::SaveAndRestore valueMapGuard(valueMap, newValueMap); |
1519 | constexpr uint32_t fakeID = static_cast<uint32_t>(-3); |
1520 | |
1521 | SmallVector<uint32_t, 4> enclosedOpResultTypeAndOperands; |
1522 | enclosedOpResultTypeAndOperands.push_back(Elt: resultTypeID); |
1523 | enclosedOpResultTypeAndOperands.push_back(Elt: fakeID); |
1524 | enclosedOpResultTypeAndOperands.append(in_start: enclosedOpOperands.begin(), |
1525 | in_end: enclosedOpOperands.end()); |
1526 | |
1527 | // Process enclosed instruction before creating the enclosing |
1528 | // specConstantOperation (and its region). This way, references to constants, |
1529 | // global variables, and spec constants will be materialized outside the new |
1530 | // op's region. For more info, see Deserializer::getValue's implementation. |
1531 | if (failed( |
1532 | processInstruction(enclosedOpcode, enclosedOpResultTypeAndOperands))) |
1533 | return Value(); |
1534 | |
1535 | // Since the enclosed op is emitted in the current block, split it in a |
1536 | // separate new block. |
1537 | Block *enclosedBlock = curBlock->splitBlock(splitBeforeOp: &curBlock->back()); |
1538 | |
1539 | auto loc = createFileLineColLoc(opBuilder); |
1540 | auto specConstOperationOp = |
1541 | opBuilder.create<spirv::SpecConstantOperationOp>(loc, resultType); |
1542 | |
1543 | Region &body = specConstOperationOp.getBody(); |
1544 | // Move the new block into SpecConstantOperation's body. |
1545 | body.getBlocks().splice(where: body.end(), L2&: curBlock->getParent()->getBlocks(), |
1546 | first: Region::iterator(enclosedBlock)); |
1547 | Block &block = body.back(); |
1548 | |
1549 | // RAII guard to reset the insertion point to the module's region after |
1550 | // deserializing the body of the specConstantOperation. |
1551 | OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder); |
1552 | opBuilder.setInsertionPointToEnd(&block); |
1553 | |
1554 | opBuilder.create<spirv::YieldOp>(loc, block.front().getResult(0)); |
1555 | return specConstOperationOp.getResult(); |
1556 | } |
1557 | |
1558 | LogicalResult |
1559 | spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) { |
1560 | if (operands.size() != 2) { |
1561 | return emitError(loc: unknownLoc, |
1562 | message: "OpConstantNull must have type <id> and result <id>" ); |
1563 | } |
1564 | |
1565 | Type resultType = getType(id: operands[0]); |
1566 | if (!resultType) { |
1567 | return emitError(loc: unknownLoc, message: "undefined result type from <id> " ) |
1568 | << operands[0]; |
1569 | } |
1570 | |
1571 | auto resultID = operands[1]; |
1572 | if (resultType.isIntOrFloat() || isa<VectorType>(Val: resultType)) { |
1573 | auto attr = opBuilder.getZeroAttr(resultType); |
1574 | // For normal constants, we just record the attribute (and its type) for |
1575 | // later materialization at use sites. |
1576 | constantMap.try_emplace(resultID, attr, resultType); |
1577 | return success(); |
1578 | } |
1579 | |
1580 | return emitError(loc: unknownLoc, message: "unsupported OpConstantNull type: " ) |
1581 | << resultType; |
1582 | } |
1583 | |
1584 | //===----------------------------------------------------------------------===// |
1585 | // Control flow |
1586 | //===----------------------------------------------------------------------===// |
1587 | |
1588 | Block *spirv::Deserializer::getOrCreateBlock(uint32_t id) { |
1589 | if (auto *block = getBlock(id)) { |
1590 | LLVM_DEBUG(logger.startLine() << "[block] got exiting block for id = " << id |
1591 | << " @ " << block << "\n" ); |
1592 | return block; |
1593 | } |
1594 | |
1595 | // We don't know where this block will be placed finally (in a |
1596 | // spirv.mlir.selection or spirv.mlir.loop or function). Create it into the |
1597 | // function for now and sort out the proper place later. |
1598 | auto *block = curFunction->addBlock(); |
1599 | LLVM_DEBUG(logger.startLine() << "[block] created block for id = " << id |
1600 | << " @ " << block << "\n" ); |
1601 | return blockMap[id] = block; |
1602 | } |
1603 | |
1604 | LogicalResult spirv::Deserializer::processBranch(ArrayRef<uint32_t> operands) { |
1605 | if (!curBlock) { |
1606 | return emitError(loc: unknownLoc, message: "OpBranch must appear inside a block" ); |
1607 | } |
1608 | |
1609 | if (operands.size() != 1) { |
1610 | return emitError(loc: unknownLoc, message: "OpBranch must take exactly one target label" ); |
1611 | } |
1612 | |
1613 | auto *target = getOrCreateBlock(id: operands[0]); |
1614 | auto loc = createFileLineColLoc(opBuilder); |
1615 | // The preceding instruction for the OpBranch instruction could be an |
1616 | // OpLoopMerge or an OpSelectionMerge instruction, in this case they will have |
1617 | // the same OpLine information. |
1618 | opBuilder.create<spirv::BranchOp>(loc, target); |
1619 | |
1620 | clearDebugLine(); |
1621 | return success(); |
1622 | } |
1623 | |
1624 | LogicalResult |
1625 | spirv::Deserializer::processBranchConditional(ArrayRef<uint32_t> operands) { |
1626 | if (!curBlock) { |
1627 | return emitError(loc: unknownLoc, |
1628 | message: "OpBranchConditional must appear inside a block" ); |
1629 | } |
1630 | |
1631 | if (operands.size() != 3 && operands.size() != 5) { |
1632 | return emitError(loc: unknownLoc, |
1633 | message: "OpBranchConditional must have condition, true label, " |
1634 | "false label, and optionally two branch weights" ); |
1635 | } |
1636 | |
1637 | auto condition = getValue(id: operands[0]); |
1638 | auto *trueBlock = getOrCreateBlock(id: operands[1]); |
1639 | auto *falseBlock = getOrCreateBlock(id: operands[2]); |
1640 | |
1641 | std::optional<std::pair<uint32_t, uint32_t>> weights; |
1642 | if (operands.size() == 5) { |
1643 | weights = std::make_pair(x: operands[3], y: operands[4]); |
1644 | } |
1645 | // The preceding instruction for the OpBranchConditional instruction could be |
1646 | // an OpSelectionMerge instruction, in this case they will have the same |
1647 | // OpLine information. |
1648 | auto loc = createFileLineColLoc(opBuilder); |
1649 | opBuilder.create<spirv::BranchConditionalOp>( |
1650 | loc, condition, trueBlock, |
1651 | /*trueArguments=*/ArrayRef<Value>(), falseBlock, |
1652 | /*falseArguments=*/ArrayRef<Value>(), weights); |
1653 | |
1654 | clearDebugLine(); |
1655 | return success(); |
1656 | } |
1657 | |
1658 | LogicalResult spirv::Deserializer::processLabel(ArrayRef<uint32_t> operands) { |
1659 | if (!curFunction) { |
1660 | return emitError(loc: unknownLoc, message: "OpLabel must appear inside a function" ); |
1661 | } |
1662 | |
1663 | if (operands.size() != 1) { |
1664 | return emitError(loc: unknownLoc, message: "OpLabel should only have result <id>" ); |
1665 | } |
1666 | |
1667 | auto labelID = operands[0]; |
1668 | // We may have forward declared this block. |
1669 | auto *block = getOrCreateBlock(id: labelID); |
1670 | LLVM_DEBUG(logger.startLine() |
1671 | << "[block] populating block " << block << "\n" ); |
1672 | // If we have seen this block, make sure it was just a forward declaration. |
1673 | assert(block->empty() && "re-deserialize the same block!" ); |
1674 | |
1675 | opBuilder.setInsertionPointToStart(block); |
1676 | blockMap[labelID] = curBlock = block; |
1677 | |
1678 | return success(); |
1679 | } |
1680 | |
1681 | LogicalResult |
1682 | spirv::Deserializer::processSelectionMerge(ArrayRef<uint32_t> operands) { |
1683 | if (!curBlock) { |
1684 | return emitError(loc: unknownLoc, message: "OpSelectionMerge must appear in a block" ); |
1685 | } |
1686 | |
1687 | if (operands.size() < 2) { |
1688 | return emitError( |
1689 | loc: unknownLoc, |
1690 | message: "OpSelectionMerge must specify merge target and selection control" ); |
1691 | } |
1692 | |
1693 | auto *mergeBlock = getOrCreateBlock(id: operands[0]); |
1694 | auto loc = createFileLineColLoc(opBuilder); |
1695 | auto selectionControl = operands[1]; |
1696 | |
1697 | if (!blockMergeInfo.try_emplace(Key: curBlock, Args&: loc, Args&: selectionControl, Args&: mergeBlock) |
1698 | .second) { |
1699 | return emitError( |
1700 | loc: unknownLoc, |
1701 | message: "a block cannot have more than one OpSelectionMerge instruction" ); |
1702 | } |
1703 | |
1704 | return success(); |
1705 | } |
1706 | |
1707 | LogicalResult |
1708 | spirv::Deserializer::processLoopMerge(ArrayRef<uint32_t> operands) { |
1709 | if (!curBlock) { |
1710 | return emitError(loc: unknownLoc, message: "OpLoopMerge must appear in a block" ); |
1711 | } |
1712 | |
1713 | if (operands.size() < 3) { |
1714 | return emitError(loc: unknownLoc, message: "OpLoopMerge must specify merge target, " |
1715 | "continue target and loop control" ); |
1716 | } |
1717 | |
1718 | auto *mergeBlock = getOrCreateBlock(id: operands[0]); |
1719 | auto *continueBlock = getOrCreateBlock(id: operands[1]); |
1720 | auto loc = createFileLineColLoc(opBuilder); |
1721 | uint32_t loopControl = operands[2]; |
1722 | |
1723 | if (!blockMergeInfo |
1724 | .try_emplace(Key: curBlock, Args&: loc, Args&: loopControl, Args&: mergeBlock, Args&: continueBlock) |
1725 | .second) { |
1726 | return emitError( |
1727 | loc: unknownLoc, |
1728 | message: "a block cannot have more than one OpLoopMerge instruction" ); |
1729 | } |
1730 | |
1731 | return success(); |
1732 | } |
1733 | |
1734 | LogicalResult spirv::Deserializer::processPhi(ArrayRef<uint32_t> operands) { |
1735 | if (!curBlock) { |
1736 | return emitError(loc: unknownLoc, message: "OpPhi must appear in a block" ); |
1737 | } |
1738 | |
1739 | if (operands.size() < 4) { |
1740 | return emitError(loc: unknownLoc, message: "OpPhi must specify result type, result <id>, " |
1741 | "and variable-parent pairs" ); |
1742 | } |
1743 | |
1744 | // Create a block argument for this OpPhi instruction. |
1745 | Type blockArgType = getType(id: operands[0]); |
1746 | BlockArgument blockArg = curBlock->addArgument(type: blockArgType, loc: unknownLoc); |
1747 | valueMap[operands[1]] = blockArg; |
1748 | LLVM_DEBUG(logger.startLine() |
1749 | << "[phi] created block argument " << blockArg |
1750 | << " id = " << operands[1] << " of type " << blockArgType << "\n" ); |
1751 | |
1752 | // For each (value, predecessor) pair, insert the value to the predecessor's |
1753 | // blockPhiInfo entry so later we can fix the block argument there. |
1754 | for (unsigned i = 2, e = operands.size(); i < e; i += 2) { |
1755 | uint32_t value = operands[i]; |
1756 | Block *predecessor = getOrCreateBlock(id: operands[i + 1]); |
1757 | std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock}; |
1758 | blockPhiInfo[predecessorTargetPair].push_back(Elt: value); |
1759 | LLVM_DEBUG(logger.startLine() << "[phi] predecessor @ " << predecessor |
1760 | << " with arg id = " << value << "\n" ); |
1761 | } |
1762 | |
1763 | return success(); |
1764 | } |
1765 | |
1766 | namespace { |
1767 | /// A class for putting all blocks in a structured selection/loop in a |
1768 | /// spirv.mlir.selection/spirv.mlir.loop op. |
1769 | class ControlFlowStructurizer { |
1770 | public: |
1771 | #ifndef NDEBUG |
1772 | ControlFlowStructurizer(Location loc, uint32_t control, |
1773 | spirv::BlockMergeInfoMap &mergeInfo, Block *, |
1774 | Block *merge, Block *cont, |
1775 | llvm::ScopedPrinter &logger) |
1776 | : location(loc), control(control), blockMergeInfo(mergeInfo), |
1777 | headerBlock(header), mergeBlock(merge), continueBlock(cont), |
1778 | logger(logger) {} |
1779 | #else |
1780 | ControlFlowStructurizer(Location loc, uint32_t control, |
1781 | spirv::BlockMergeInfoMap &mergeInfo, Block *header, |
1782 | Block *merge, Block *cont) |
1783 | : location(loc), control(control), blockMergeInfo(mergeInfo), |
1784 | headerBlock(header), mergeBlock(merge), continueBlock(cont) {} |
1785 | #endif |
1786 | |
1787 | /// Structurizes the loop at the given `headerBlock`. |
1788 | /// |
1789 | /// This method will create an spirv.mlir.loop op in the `mergeBlock` and move |
1790 | /// all blocks in the structured loop into the spirv.mlir.loop's region. All |
1791 | /// branches to the `headerBlock` will be redirected to the `mergeBlock`. This |
1792 | /// method will also update `mergeInfo` by remapping all blocks inside to the |
1793 | /// newly cloned ones inside structured control flow op's regions. |
1794 | LogicalResult structurize(); |
1795 | |
1796 | private: |
1797 | /// Creates a new spirv.mlir.selection op at the beginning of the |
1798 | /// `mergeBlock`. |
1799 | spirv::SelectionOp createSelectionOp(uint32_t selectionControl); |
1800 | |
1801 | /// Creates a new spirv.mlir.loop op at the beginning of the `mergeBlock`. |
1802 | spirv::LoopOp createLoopOp(uint32_t loopControl); |
1803 | |
1804 | /// Collects all blocks reachable from `headerBlock` except `mergeBlock`. |
1805 | void collectBlocksInConstruct(); |
1806 | |
1807 | Location location; |
1808 | uint32_t control; |
1809 | |
1810 | spirv::BlockMergeInfoMap &blockMergeInfo; |
1811 | |
1812 | Block *; |
1813 | Block *mergeBlock; |
1814 | Block *continueBlock; // nullptr for spirv.mlir.selection |
1815 | |
1816 | SetVector<Block *> constructBlocks; |
1817 | |
1818 | #ifndef NDEBUG |
1819 | /// A logger used to emit information during the deserialzation process. |
1820 | llvm::ScopedPrinter &logger; |
1821 | #endif |
1822 | }; |
1823 | } // namespace |
1824 | |
1825 | spirv::SelectionOp |
1826 | ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) { |
1827 | // Create a builder and set the insertion point to the beginning of the |
1828 | // merge block so that the newly created SelectionOp will be inserted there. |
1829 | OpBuilder builder(&mergeBlock->front()); |
1830 | |
1831 | auto control = static_cast<spirv::SelectionControl>(selectionControl); |
1832 | auto selectionOp = builder.create<spirv::SelectionOp>(location, control); |
1833 | selectionOp.addMergeBlock(builder); |
1834 | |
1835 | return selectionOp; |
1836 | } |
1837 | |
1838 | spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) { |
1839 | // Create a builder and set the insertion point to the beginning of the |
1840 | // merge block so that the newly created LoopOp will be inserted there. |
1841 | OpBuilder builder(&mergeBlock->front()); |
1842 | |
1843 | auto control = static_cast<spirv::LoopControl>(loopControl); |
1844 | auto loopOp = builder.create<spirv::LoopOp>(location, control); |
1845 | loopOp.addEntryAndMergeBlock(builder); |
1846 | |
1847 | return loopOp; |
1848 | } |
1849 | |
1850 | void ControlFlowStructurizer::collectBlocksInConstruct() { |
1851 | assert(constructBlocks.empty() && "expected empty constructBlocks" ); |
1852 | |
1853 | // Put the header block in the work list first. |
1854 | constructBlocks.insert(X: headerBlock); |
1855 | |
1856 | // For each item in the work list, add its successors excluding the merge |
1857 | // block. |
1858 | for (unsigned i = 0; i < constructBlocks.size(); ++i) { |
1859 | for (auto *successor : constructBlocks[i]->getSuccessors()) |
1860 | if (successor != mergeBlock) |
1861 | constructBlocks.insert(X: successor); |
1862 | } |
1863 | } |
1864 | |
1865 | LogicalResult ControlFlowStructurizer::structurize() { |
1866 | Operation *op = nullptr; |
1867 | bool isLoop = continueBlock != nullptr; |
1868 | if (isLoop) { |
1869 | if (auto loopOp = createLoopOp(control)) |
1870 | op = loopOp.getOperation(); |
1871 | } else { |
1872 | if (auto selectionOp = createSelectionOp(control)) |
1873 | op = selectionOp.getOperation(); |
1874 | } |
1875 | if (!op) |
1876 | return failure(); |
1877 | Region &body = op->getRegion(index: 0); |
1878 | |
1879 | IRMapping mapper; |
1880 | // All references to the old merge block should be directed to the |
1881 | // selection/loop merge block in the SelectionOp/LoopOp's region. |
1882 | mapper.map(from: mergeBlock, to: &body.back()); |
1883 | |
1884 | collectBlocksInConstruct(); |
1885 | |
1886 | // We've identified all blocks belonging to the selection/loop's region. Now |
1887 | // need to "move" them into the selection/loop. Instead of really moving the |
1888 | // blocks, in the following we copy them and remap all values and branches. |
1889 | // This is because: |
1890 | // * Inserting a block into a region requires the block not in any region |
1891 | // before. But selections/loops can nest so we can create selection/loop ops |
1892 | // in a nested manner, which means some blocks may already be in a |
1893 | // selection/loop region when to be moved again. |
1894 | // * It's much trickier to fix up the branches into and out of the loop's |
1895 | // region: we need to treat not-moved blocks and moved blocks differently: |
1896 | // Not-moved blocks jumping to the loop header block need to jump to the |
1897 | // merge point containing the new loop op but not the loop continue block's |
1898 | // back edge. Moved blocks jumping out of the loop need to jump to the |
1899 | // merge block inside the loop region but not other not-moved blocks. |
1900 | // We cannot use replaceAllUsesWith clearly and it's harder to follow the |
1901 | // logic. |
1902 | |
1903 | // Create a corresponding block in the SelectionOp/LoopOp's region for each |
1904 | // block in this loop construct. |
1905 | OpBuilder builder(body); |
1906 | for (auto *block : constructBlocks) { |
1907 | // Create a block and insert it before the selection/loop merge block in the |
1908 | // SelectionOp/LoopOp's region. |
1909 | auto *newBlock = builder.createBlock(insertBefore: &body.back()); |
1910 | mapper.map(from: block, to: newBlock); |
1911 | LLVM_DEBUG(logger.startLine() << "[cf] cloned block " << newBlock |
1912 | << " from block " << block << "\n" ); |
1913 | if (!isFnEntryBlock(block)) { |
1914 | for (BlockArgument blockArg : block->getArguments()) { |
1915 | auto newArg = |
1916 | newBlock->addArgument(type: blockArg.getType(), loc: blockArg.getLoc()); |
1917 | mapper.map(from: blockArg, to: newArg); |
1918 | LLVM_DEBUG(logger.startLine() << "[cf] remapped block argument " |
1919 | << blockArg << " to " << newArg << "\n" ); |
1920 | } |
1921 | } else { |
1922 | LLVM_DEBUG(logger.startLine() |
1923 | << "[cf] block " << block << " is a function entry block\n" ); |
1924 | } |
1925 | |
1926 | for (auto &op : *block) |
1927 | newBlock->push_back(op: op.clone(mapper)); |
1928 | } |
1929 | |
1930 | // Go through all ops and remap the operands. |
1931 | auto remapOperands = [&](Operation *op) { |
1932 | for (auto &operand : op->getOpOperands()) |
1933 | if (Value mappedOp = mapper.lookupOrNull(from: operand.get())) |
1934 | operand.set(mappedOp); |
1935 | for (auto &succOp : op->getBlockOperands()) |
1936 | if (Block *mappedOp = mapper.lookupOrNull(from: succOp.get())) |
1937 | succOp.set(mappedOp); |
1938 | }; |
1939 | for (auto &block : body) |
1940 | block.walk(callback&: remapOperands); |
1941 | |
1942 | // We have created the SelectionOp/LoopOp and "moved" all blocks belonging to |
1943 | // the selection/loop construct into its region. Next we need to fix the |
1944 | // connections between this new SelectionOp/LoopOp with existing blocks. |
1945 | |
1946 | // All existing incoming branches should go to the merge block, where the |
1947 | // SelectionOp/LoopOp resides right now. |
1948 | headerBlock->replaceAllUsesWith(newValue&: mergeBlock); |
1949 | |
1950 | LLVM_DEBUG({ |
1951 | logger.startLine() << "[cf] after cloning and fixing references:\n" ; |
1952 | headerBlock->getParentOp()->print(logger.getOStream()); |
1953 | logger.startLine() << "\n" ; |
1954 | }); |
1955 | |
1956 | if (isLoop) { |
1957 | if (!mergeBlock->args_empty()) { |
1958 | return mergeBlock->getParentOp()->emitError( |
1959 | message: "OpPhi in loop merge block unsupported" ); |
1960 | } |
1961 | |
1962 | // The loop header block may have block arguments. Since now we place the |
1963 | // loop op inside the old merge block, we need to make sure the old merge |
1964 | // block has the same block argument list. |
1965 | for (BlockArgument blockArg : headerBlock->getArguments()) |
1966 | mergeBlock->addArgument(type: blockArg.getType(), loc: blockArg.getLoc()); |
1967 | |
1968 | // If the loop header block has block arguments, make sure the spirv.Branch |
1969 | // op matches. |
1970 | SmallVector<Value, 4> blockArgs; |
1971 | if (!headerBlock->args_empty()) |
1972 | blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()}; |
1973 | |
1974 | // The loop entry block should have a unconditional branch jumping to the |
1975 | // loop header block. |
1976 | builder.setInsertionPointToEnd(&body.front()); |
1977 | builder.create<spirv::BranchOp>(location, mapper.lookupOrNull(headerBlock), |
1978 | ArrayRef<Value>(blockArgs)); |
1979 | } |
1980 | |
1981 | // All the blocks cloned into the SelectionOp/LoopOp's region can now be |
1982 | // cleaned up. |
1983 | LLVM_DEBUG(logger.startLine() << "[cf] cleaning up blocks after clone\n" ); |
1984 | // First we need to drop all operands' references inside all blocks. This is |
1985 | // needed because we can have blocks referencing SSA values from one another. |
1986 | for (auto *block : constructBlocks) |
1987 | block->dropAllReferences(); |
1988 | |
1989 | // Check that whether some op in the to-be-erased blocks still has uses. Those |
1990 | // uses come from blocks that won't be sinked into the SelectionOp/LoopOp's |
1991 | // region. We cannot handle such cases given that once a value is sinked into |
1992 | // the SelectionOp/LoopOp's region, there is no escape for it: |
1993 | // SelectionOp/LooOp does not support yield values right now. |
1994 | for (auto *block : constructBlocks) { |
1995 | for (Operation &op : *block) |
1996 | if (!op.use_empty()) |
1997 | return op.emitOpError( |
1998 | message: "failed control flow structurization: it has uses outside of the " |
1999 | "enclosing selection/loop construct" ); |
2000 | } |
2001 | |
2002 | // Then erase all old blocks. |
2003 | for (auto *block : constructBlocks) { |
2004 | // We've cloned all blocks belonging to this construct into the structured |
2005 | // control flow op's region. Among these blocks, some may compose another |
2006 | // selection/loop. If so, they will be recorded within blockMergeInfo. |
2007 | // We need to update the pointers there to the newly remapped ones so we can |
2008 | // continue structurizing them later. |
2009 | // TODO: The asserts in the following assumes input SPIR-V blob forms |
2010 | // correctly nested selection/loop constructs. We should relax this and |
2011 | // support error cases better. |
2012 | auto it = blockMergeInfo.find(Val: block); |
2013 | if (it != blockMergeInfo.end()) { |
2014 | // Use the original location for nested selection/loop ops. |
2015 | Location loc = it->second.loc; |
2016 | |
2017 | Block * = mapper.lookupOrNull(from: block); |
2018 | if (!newHeader) |
2019 | return emitError(loc, message: "failed control flow structurization: nested " |
2020 | "loop header block should be remapped!" ); |
2021 | |
2022 | Block *newContinue = it->second.continueBlock; |
2023 | if (newContinue) { |
2024 | newContinue = mapper.lookupOrNull(from: newContinue); |
2025 | if (!newContinue) |
2026 | return emitError(loc, message: "failed control flow structurization: nested " |
2027 | "loop continue block should be remapped!" ); |
2028 | } |
2029 | |
2030 | Block *newMerge = it->second.mergeBlock; |
2031 | if (Block *mappedTo = mapper.lookupOrNull(from: newMerge)) |
2032 | newMerge = mappedTo; |
2033 | |
2034 | // The iterator should be erased before adding a new entry into |
2035 | // blockMergeInfo to avoid iterator invalidation. |
2036 | blockMergeInfo.erase(I: it); |
2037 | blockMergeInfo.try_emplace(Key: newHeader, Args&: loc, Args&: it->second.control, Args&: newMerge, |
2038 | Args&: newContinue); |
2039 | } |
2040 | |
2041 | // The structured selection/loop's entry block does not have arguments. |
2042 | // If the function's header block is also part of the structured control |
2043 | // flow, we cannot just simply erase it because it may contain arguments |
2044 | // matching the function signature and used by the cloned blocks. |
2045 | if (isFnEntryBlock(block)) { |
2046 | LLVM_DEBUG(logger.startLine() << "[cf] changing entry block " << block |
2047 | << " to only contain a spirv.Branch op\n" ); |
2048 | // Still keep the function entry block for the potential block arguments, |
2049 | // but replace all ops inside with a branch to the merge block. |
2050 | block->clear(); |
2051 | builder.setInsertionPointToEnd(block); |
2052 | builder.create<spirv::BranchOp>(location, mergeBlock); |
2053 | } else { |
2054 | LLVM_DEBUG(logger.startLine() << "[cf] erasing block " << block << "\n" ); |
2055 | block->erase(); |
2056 | } |
2057 | } |
2058 | |
2059 | LLVM_DEBUG(logger.startLine() |
2060 | << "[cf] after structurizing construct with header block " |
2061 | << headerBlock << ":\n" |
2062 | << *op << "\n" ); |
2063 | |
2064 | return success(); |
2065 | } |
2066 | |
2067 | LogicalResult spirv::Deserializer::wireUpBlockArgument() { |
2068 | LLVM_DEBUG({ |
2069 | logger.startLine() |
2070 | << "//----- [phi] start wiring up block arguments -----//\n" ; |
2071 | logger.indent(); |
2072 | }); |
2073 | |
2074 | OpBuilder::InsertionGuard guard(opBuilder); |
2075 | |
2076 | for (const auto &info : blockPhiInfo) { |
2077 | Block *block = info.first.first; |
2078 | Block *target = info.first.second; |
2079 | const BlockPhiInfo &phiInfo = info.second; |
2080 | LLVM_DEBUG({ |
2081 | logger.startLine() << "[phi] block " << block << "\n" ; |
2082 | logger.startLine() << "[phi] before creating block argument:\n" ; |
2083 | block->getParentOp()->print(logger.getOStream()); |
2084 | logger.startLine() << "\n" ; |
2085 | }); |
2086 | |
2087 | // Set insertion point to before this block's terminator early because we |
2088 | // may materialize ops via getValue() call. |
2089 | auto *op = block->getTerminator(); |
2090 | opBuilder.setInsertionPoint(op); |
2091 | |
2092 | SmallVector<Value, 4> blockArgs; |
2093 | blockArgs.reserve(N: phiInfo.size()); |
2094 | for (uint32_t valueId : phiInfo) { |
2095 | if (Value value = getValue(id: valueId)) { |
2096 | blockArgs.push_back(Elt: value); |
2097 | LLVM_DEBUG(logger.startLine() << "[phi] block argument " << value |
2098 | << " id = " << valueId << "\n" ); |
2099 | } else { |
2100 | return emitError(loc: unknownLoc, message: "OpPhi references undefined value!" ); |
2101 | } |
2102 | } |
2103 | |
2104 | if (auto branchOp = dyn_cast<spirv::BranchOp>(op)) { |
2105 | // Replace the previous branch op with a new one with block arguments. |
2106 | opBuilder.create<spirv::BranchOp>(branchOp.getLoc(), branchOp.getTarget(), |
2107 | blockArgs); |
2108 | branchOp.erase(); |
2109 | } else if (auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) { |
2110 | assert((branchCondOp.getTrueBlock() == target || |
2111 | branchCondOp.getFalseBlock() == target) && |
2112 | "expected target to be either the true or false target" ); |
2113 | if (target == branchCondOp.getTrueTarget()) |
2114 | opBuilder.create<spirv::BranchConditionalOp>( |
2115 | branchCondOp.getLoc(), branchCondOp.getCondition(), blockArgs, |
2116 | branchCondOp.getFalseBlockArguments(), |
2117 | branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueTarget(), |
2118 | branchCondOp.getFalseTarget()); |
2119 | else |
2120 | opBuilder.create<spirv::BranchConditionalOp>( |
2121 | branchCondOp.getLoc(), branchCondOp.getCondition(), |
2122 | branchCondOp.getTrueBlockArguments(), blockArgs, |
2123 | branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueBlock(), |
2124 | branchCondOp.getFalseBlock()); |
2125 | |
2126 | branchCondOp.erase(); |
2127 | } else { |
2128 | return emitError(loc: unknownLoc, message: "unimplemented terminator for Phi creation" ); |
2129 | } |
2130 | |
2131 | LLVM_DEBUG({ |
2132 | logger.startLine() << "[phi] after creating block argument:\n" ; |
2133 | block->getParentOp()->print(logger.getOStream()); |
2134 | logger.startLine() << "\n" ; |
2135 | }); |
2136 | } |
2137 | blockPhiInfo.clear(); |
2138 | |
2139 | LLVM_DEBUG({ |
2140 | logger.unindent(); |
2141 | logger.startLine() |
2142 | << "//--- [phi] completed wiring up block arguments ---//\n" ; |
2143 | }); |
2144 | return success(); |
2145 | } |
2146 | |
2147 | LogicalResult spirv::Deserializer::structurizeControlFlow() { |
2148 | LLVM_DEBUG({ |
2149 | logger.startLine() |
2150 | << "//----- [cf] start structurizing control flow -----//\n" ; |
2151 | logger.indent(); |
2152 | }); |
2153 | |
2154 | while (!blockMergeInfo.empty()) { |
2155 | Block * = blockMergeInfo.begin()->first; |
2156 | BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second; |
2157 | |
2158 | LLVM_DEBUG({ |
2159 | logger.startLine() << "[cf] header block " << headerBlock << ":\n" ; |
2160 | headerBlock->print(logger.getOStream()); |
2161 | logger.startLine() << "\n" ; |
2162 | }); |
2163 | |
2164 | auto *mergeBlock = mergeInfo.mergeBlock; |
2165 | assert(mergeBlock && "merge block cannot be nullptr" ); |
2166 | if (!mergeBlock->args_empty()) |
2167 | return emitError(loc: unknownLoc, message: "OpPhi in loop merge block unimplemented" ); |
2168 | LLVM_DEBUG({ |
2169 | logger.startLine() << "[cf] merge block " << mergeBlock << ":\n" ; |
2170 | mergeBlock->print(logger.getOStream()); |
2171 | logger.startLine() << "\n" ; |
2172 | }); |
2173 | |
2174 | auto *continueBlock = mergeInfo.continueBlock; |
2175 | LLVM_DEBUG(if (continueBlock) { |
2176 | logger.startLine() << "[cf] continue block " << continueBlock << ":\n" ; |
2177 | continueBlock->print(logger.getOStream()); |
2178 | logger.startLine() << "\n" ; |
2179 | }); |
2180 | // Erase this case before calling into structurizer, who will update |
2181 | // blockMergeInfo. |
2182 | blockMergeInfo.erase(I: blockMergeInfo.begin()); |
2183 | ControlFlowStructurizer structurizer(mergeInfo.loc, mergeInfo.control, |
2184 | blockMergeInfo, headerBlock, |
2185 | mergeBlock, continueBlock |
2186 | #ifndef NDEBUG |
2187 | , |
2188 | logger |
2189 | #endif |
2190 | ); |
2191 | if (failed(result: structurizer.structurize())) |
2192 | return failure(); |
2193 | } |
2194 | |
2195 | LLVM_DEBUG({ |
2196 | logger.unindent(); |
2197 | logger.startLine() |
2198 | << "//--- [cf] completed structurizing control flow ---//\n" ; |
2199 | }); |
2200 | return success(); |
2201 | } |
2202 | |
2203 | //===----------------------------------------------------------------------===// |
2204 | // Debug |
2205 | //===----------------------------------------------------------------------===// |
2206 | |
2207 | Location spirv::Deserializer::createFileLineColLoc(OpBuilder opBuilder) { |
2208 | if (!debugLine) |
2209 | return unknownLoc; |
2210 | |
2211 | auto fileName = debugInfoMap.lookup(Val: debugLine->fileID).str(); |
2212 | if (fileName.empty()) |
2213 | fileName = "<unknown>" ; |
2214 | return FileLineColLoc::get(opBuilder.getStringAttr(fileName), debugLine->line, |
2215 | debugLine->column); |
2216 | } |
2217 | |
2218 | LogicalResult |
2219 | spirv::Deserializer::processDebugLine(ArrayRef<uint32_t> operands) { |
2220 | // According to SPIR-V spec: |
2221 | // "This location information applies to the instructions physically |
2222 | // following this instruction, up to the first occurrence of any of the |
2223 | // following: the next end of block, the next OpLine instruction, or the next |
2224 | // OpNoLine instruction." |
2225 | if (operands.size() != 3) |
2226 | return emitError(loc: unknownLoc, message: "OpLine must have 3 operands" ); |
2227 | debugLine = DebugLine{.fileID: operands[0], .line: operands[1], .column: operands[2]}; |
2228 | return success(); |
2229 | } |
2230 | |
2231 | void spirv::Deserializer::clearDebugLine() { debugLine = std::nullopt; } |
2232 | |
2233 | LogicalResult |
2234 | spirv::Deserializer::processDebugString(ArrayRef<uint32_t> operands) { |
2235 | if (operands.size() < 2) |
2236 | return emitError(loc: unknownLoc, message: "OpString needs at least 2 operands" ); |
2237 | |
2238 | if (!debugInfoMap.lookup(Val: operands[0]).empty()) |
2239 | return emitError(loc: unknownLoc, |
2240 | message: "duplicate debug string found for result <id> " ) |
2241 | << operands[0]; |
2242 | |
2243 | unsigned wordIndex = 1; |
2244 | StringRef debugString = decodeStringLiteral(words: operands, wordIndex); |
2245 | if (wordIndex != operands.size()) |
2246 | return emitError(loc: unknownLoc, |
2247 | message: "unexpected trailing words in OpString instruction" ); |
2248 | |
2249 | debugInfoMap[operands[0]] = debugString; |
2250 | return success(); |
2251 | } |
2252 | |